Skip to content

Commit 584311a

Browse files
committed
More fixes and polish for the shapes&einsum slides
1 parent 0690a1b commit 584311a

File tree

1 file changed

+25
-11
lines changed

1 file changed

+25
-11
lines changed

docs/slides-shapes_and_einsum.md

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -141,13 +141,15 @@ let%op attention ~num_heads () x =
141141
(* Matrix multiplication on individual output axes *)
142142
let%op matmul a b = a +* "ik; kj => ij" b
143143
144-
(* Batch matrix multiply on individual output axes with broadcasting *)
144+
(* Batch matrix multiply on individual output axes
145+
with broadcasting *)
145146
let%op batch_matmul a b =
146147
a +* "... | ik; ... | kj => ... | ij" b
147148
148149
(* Full tensor multiplication, equivalent to [a * b] *)
149150
let%op tensor_mul a b =
150-
a +* "... | ..k..->..i..; ... | ..j..->..k.. => ... | ..j..->..i.." b
151+
a +* "... | ..k..->..i..; ... | ..j..->..k..
152+
=> ... | ..j..->..i.." b
151153
152154
(* Max pooling, requiring specifically 4 output axes *)
153155
let%op max_reduce_ouput x = x @^^ "bhwc => b00c"
@@ -167,6 +169,8 @@ let%op max_reduce x = x @^^ "...|hwc => ...|00c"
167169
> ```
168170
> "stride*output + dilation*kernel"
169171
> ```
172+
>
173+
> Within the syntax extensions, `stride` and `dilation` can be identifiers bound to `int` values, not only integer literals.
170174
171175
{pause}
172176
@@ -253,12 +257,13 @@ You can programmatically create the spec for use with the dedicated syntaxes, bu
253257
```ocaml
254258
(* Reduce last N output dimensions, PyTorch-style keepdim *)
255259
let%op reduce_last_n ~n ?(keepdim = true) () =
256-
let vars = List.init n ~f:(fun i ->
257-
Char.to_string (Char.of_int_exn (97 + i))) in
260+
let vars =
261+
[%oc List.init n ~f:(fun i ->
262+
Char.to_string (Char.of_int_exn (97 + i)))] in
258263
let result_dims =
259-
if keepdim then String.make n '0' else "" in
260-
let spec = "... | ..." ^ String.concat "" vars ^
261-
" => ... | ..." ^ result_dims in
264+
[%oc if keepdim then String.make n '0' else ""] in
265+
let spec = [%oc "... | ..." ^ String.concat "" vars ^
266+
" => ... | ..." ^ result_dims] in
262267
fun x -> x ++ spec
263268
264269
(* Example: reduce_last_n ~n:3 ~keepdim:true ()
@@ -268,6 +273,10 @@ let%op reduce_last_n ~n ?(keepdim = true) () =
268273
generates: "... | ...abc => ... | ..." *)
269274
```
270275

276+
{pause}
277+
278+
The `[%oc ...]` syntax embeds arbitrary OCaml code that `%op` will not attempt to interpret as tensor expressions.
279+
271280
{pause up}
272281
## Practical Patterns
273282

@@ -352,7 +361,7 @@ a +* "ij; jk => ik" b
352361
## Tips and Tricks
353362

354363
{#tips .block title="Best Practices"}
355-
> 1. **Use `|` for axis kinds** when mixing batch/input/output
364+
> 1. **Use `|` and `->` for axis kinds** when there's a meaningful batch/input/output split
356365
> 2. **Add trailing comma** for multi-char mode: `"input->output,"`
357366
> 3. **Avoid over-capturing** dimensions in einsum specs
358367
> 4. **Remember tensor operators**: `*` (matmul) vs `*.` (pointwise)
@@ -363,12 +372,17 @@ a +* "ij; jk => ik" b
363372
{#debugging .remark title="Debugging Shapes"}
364373
> When shapes don't match:
365374
>
366-
> * Print tensor shapes: `Tensor.print ~force:true tensor`
375+
> * Print tensor shapes: `Tensor.print ~force:true tensor`
376+
> [but not before all relevant tensor expressions are constructed]{.unrevealed #premature-inference-finalize}
367377
> * Check axis kinds are correctly specified
368378
> * Verify broadcasting assumptions
369379
> * Use explicit dimension constraints when needed
370380
371-
{pause up}
381+
{pause reveal=premature-inference-finalize}
382+
383+
{pause focus=premature-inference-finalize}
384+
385+
{pause unfocus up}
372386
## Common Pitfalls
373387

374388
**Tensor operators matter:**
@@ -475,4 +489,4 @@ Check [nn_blocks.ml](https://github.com/ahrefs/ocannl/blob/master/lib/nn_blocks.
475489
* [doc/syntax_extensions.md](syntax_extensions.html) - Full `%op` and `%cd` syntax
476490
* [lib/shape.mli](../dev/neural_nets_lib/Ocannl/Shape/index.html) - Shape inference internals
477491
* [lib/nn_blocks.ml](https://github.com/ahrefs/ocannl/blob/master/lib/nn_blocks.ml#L68) - Production examples
478-
* `test/einsum_trivia.ml` - Einsum test cases
492+
* [test/einsum_trivia.ml](https://github.com/ahrefs/ocannl/blob/master/test/einsum/einsum_trivia.ml) - Einsum test cases

0 commit comments

Comments
 (0)