Skip to content

Commit 2766112

Browse files
committed
Fixes #326: Fix the wrongly implied assumption that in einsum spec, axes of omitted kind get broadcasted or reduced
1 parent b066a3b commit 2766112

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

lib/shape.mli

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
- separators_with_comma: commas and whitespaces containing at least one comma.
1616
- axes_spec_single_char: separators? identifier+ separators?
1717
- axes_spec_multichar: separators? (identifier separators_with_comma)* identifier separators?
18-
- conv_expression: term '+' term where term is [coeff '*'] identifier and coeff is integer
18+
- conv_expression: term '+' term
19+
- term: (coeff '*')? identifier
20+
- coeff: integer -- note that syntax extensions will splice in the value of an OCaml identifier
1921
- ellipsis_spec: '...' <|> '..' identifier '..'
2022
- row_spec: axes_spec <|> ellipsis_spec axes_spec <|> axes_spec ellipsis_spec axes_spec
2123
- labels_spec: row_spec <|> row_spec '|' row_spec <|> row_spec '->' row_spec <|> row_spec '|'

lib/syntax_extensions.md

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -344,16 +344,18 @@ The syntax of an axis spec:
344344

345345
Examples:
346346

347-
- `...|...->... => 0`, `...|... => 0` and `... => 0` are equivalent: reduce all axes of the argument into a single number. Useful e.g. for reducing losses to a single number.
347+
- `...|...->... => 0`: reduce all axes of the argument into a single number. Useful e.g. for reducing losses to a single number.
348+
- `...|... => 0`, `...->... => 0`, `... => 0` do the same but will fail if the argument has axes of the kind for which the ellipsis is missing.
348349
- `...|...->... => ...|...->...`: fully pointwise unary operation.
350+
- `...->... => ...->...`, `...|... => ...|...`, `... => ...`: fully pointwise but will fail if the argument has axes of the kind for which the ellipsis is missing.
349351
- `...|...->... ; ...|...->... => ...|...->...`: fully pointwise binary operation.
350-
- `...|...->... => ...->...` and `...->... => ...->...` are equivalent: reduce the batch axes into the result.
352+
- `...|...->... => ...->...`: reduce the batch axes into the result.
351353
- `2...|...->... => ...|...->...`: slice the tensor at dimension 2 of the leftmost batch axis. Note that the tensor operation `@|` implements slicing at the leftmost batch axis for arbitrary dimension.
352-
- `...|... => ...|...2`: expand the tensor by putting the argument at leftmost output dimension 2 of the result (and reduce input axes if any). `rhs ++ "...|... => ...|...2"` will fill the other cells of the new tensor with zeroes; `[%cd lhs =:* rhs ~logic:"...|... => ...|...2"]` will fill the other cells of `lhs` with ones since it's the neutral element of the assignment (reduction) operator.
353-
- `ijk => kji`: reverse the three rightmost output axes, reduce any other axes.
354+
- `...|... => ...|...2`: expand the tensor by putting the argument at leftmost output dimension 2 of the result (and reduce input axes if any). `rhs ++ "...|... => ...|...2"` will fill the other cells of the new tensor with zeroes; `[%cd lhs =:* rhs ~logic:"...|... => ...|...2"]` will fill the other cells of `lhs` with ones since it's the neutral element of the assignment (reduction) operator, here with ones.
355+
- `ijk => kji`: reverse the three output axes, fails if the argument has any other axes.
354356
- `ijk => ki`: as above but also reduce the second-leftmost output axis.
355-
- `..v..|ijk => ..v..kji`: reverse the three rightmost output axes, reduce any other output and input axes, pointwise for batch axes, pairing the batch axes with the leftmost output axes of the result.
356-
- `2..v..|... => ..v..`: slice the tensor at dimension 2 of the leftmost batch axis, reduce all its input and output axes, preserve its other batch axes as output axes.
357+
- `..v..|...ijk => ..v..kji`: reverse the three rightmost output axes, reduce any other output axes, pointwise for batch axes, pairing the batch axes with the leftmost output axes of the result. Fails if the argument has input axes.
358+
- `2..v..|... => ..v..`: slice the tensor at dimension 2 of the leftmost batch axis, reduce all its output axes, preserve its other batch axes as output axes. Fails if the argument has input axes.
357359

358360
## Further features of the syntax extension %cd
359361

@@ -413,7 +415,7 @@ If you recall, inline declared param tensors get lifted out of functions except
413415

414416
```ocaml
415417
let mlp_layer ~config =
416-
let w = TDSL.param "w" and b = TDSL.param ~output_dims:[ config.hid_dim ] in
418+
let w = TDSL.param "w" and b = TDSL.param ~output_dims:[ config.hid_dim ] "b" in
417419
fun x -> TDSL.O.(w * x + b)
418420
```
419421

@@ -519,4 +521,4 @@ type comp = {
519521
}
520522
```
521523

522-
The tensor nodes that are in `asgns` but not in `embedded_nodes`, and are on-device, must already be present in contexts with which the computation is linked. Such non-embedded nodes can be seen as inputs to the computation -- except that for `backprop` code of a tensor, they are actually the outputs! Embedded nodes are closely related to _rootness_ -- when a node has not been used in the code of another tensor, it is a root (a forward root for value nodes and a backprop root for grad nodes). `embedded_nodes` were roots the first time they were used in `asgns`.
524+
The tensor nodes that are in `asgns` but not in `embedded_nodes`, and are on-device, must already be present in contexts with which the computation is linked. Such non-embedded nodes can be seen as inputs to the computation -- except that for `backprop` code of a tensor, they are actually the outputs! Embedded nodes are closely related to _rootness_ -- when a node has not been used in the code of another tensor, it is a root (a forward root for value nodes and a backprop root for grad nodes). `embedded_nodes` were roots the first time they were used in `asgns`. Parameters, as created by `Tensor.param`, are not embedded in the code that uses them and thus will not be in `embedded_nodes` of the forward and backprop code over the parameters; however, they will constitute the `embedded_nodes` of the `Tensor.init_params` code.

0 commit comments

Comments
 (0)