Skip to content

Commit 1015b65

Browse files
lukstaficlaude
andcommitted
Update shape_inference.md for Affine type and stride_offset semantics
- Update type definition from Conv_input to Affine with convolution option - Replace "Convolution-based indexing" section with "Affine indexing and convolutions" - Document the key insight: stride_offset is projection-time only - Add dimension formulas with derivation from max input index - Clarify projection inference behavior for use_padding true/false 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent faab4f3 commit 1015b65

File tree

1 file changed

+31
-12
lines changed

1 file changed

+31
-12
lines changed

docs/shape_inference.md

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,16 @@ A row is a sequence of axes of a single kind: batch, input, or output. The shape
1818
type solved_dim = { d : int; label : string option; proj_id : proj_id option }
1919
(** A single axis in a shape. *)
2020
21+
type convolution = { dilation : int; kernel : dim; use_padding : bool }
22+
(** Convolution parameters. *)
23+
2124
type dim =
2225
| Var of dim_var
2326
| Dim of solved_dim
24-
| Conv_input of { stride : int; output : dim; dilation : int; kernel : dim }
25-
(** The offset is implicit, automatically derived. If [!use_padding] is [true], the offset is
26-
the left part of the dimensionality-preserving symmetric padding, otherwise it is 0. If
27-
[!use_padding] is [true], the value stands for dimensions size [stride * output],
28-
otherwise for dimensions size [stride * output + dilation * kernel]. If [dilation = 0],
29-
the value stands for projections of strided iteration rather than convolution. *)
27+
| Affine of { stride : int; over : dim; conv : convolution option; stride_offset : int }
28+
(** An affine transformation of a dimension. When [conv] is [None], this is strided
29+
iteration. When [conv] is [Some], this includes convolution parameters.
30+
Invariants: [stride > 0], [dilation > 0] (when present), [0 <= stride_offset < stride]. *)
3031
3132
type bcast =
3233
| Row_var of row_var (** The row can be inferred to have more axes. *)
@@ -47,19 +48,37 @@ The actual implementation is split into the `Row` module, which handles multi-ro
4748

4849
Labels are a part of OCANNL, but it's a topic that needs more exploration and future work. Currently, OCANNL has labeled dimensions, but not labeled axes. This means that when two axes need to agree on the number of dimensions, they also need to agree on the labels. If the dimensions of both axes have labels, the labels need to be the same, and if one doesn't have a label initially, it's inferred to have the label from the other axis. Intuitively, the label is a specification of the semantics of an axis that is more fine-grained than, but of similar nature as, the number of dimensions. Currently, there is no global check to prevent the same label be used with different numbers of dimensions (on unrelated axes). Example: a label `"rgb"` attached to dimensions size 3 to denote that an axis represents three channels "red", "green" and "blue".
4950

50-
### Convolution-based indexing
51+
### Affine indexing and convolutions
52+
53+
The `Affine` constructor represents affine transformations of dimensions, supporting both strided iteration and convolution-style indexing. The key insight is that `stride_offset` is purely a projection-time concern—it selects which elements within each stride window are accessed, but does not affect dimension sizes. This allows the same input tensor to be used with different `stride_offset` values (0 to stride-1) without shape errors.
54+
55+
**Dimension formulas** (during shape inference):
5156

52-
The `Conv_input` constructor represents convolution-style input dimensions that enable support for operations like convolutions where output indices relate to input indices through the relationship:
57+
- `conv = None` (strided iteration): `input_dim = stride * output_dim`
58+
- `conv = Some { use_padding = true; _ }`: `input_dim = stride * output_dim` (padding compensates for kernel)
59+
- `conv = Some { use_padding = false; dilation; kernel }`: `input_dim = stride * output_dim + dilation * (kernel_dim - 1)`
60+
61+
The formula for `use_padding = false` comes from computing the maximum input index. For output position `o` (0 to output_dim-1) and kernel position `j` (0 to kernel_dim-1), the input index is:
5362

5463
```
55-
input_dimension = stride * output_iterator + dilation * kernel_iterator
64+
input_index = o * stride + stride_offset + j * dilation
5665
```
5766

58-
When `use_padding` is true, the offset is chosen to preserve dimensionality (i.e., output size equals input size for stride=1). When false, the offset is 0 (no padding).
67+
The maximum input index occurs at `o = output_dim - 1`, `j = kernel_dim - 1`, and `stride_offset = stride - 1` (worst case for any valid stride_offset):
68+
69+
```
70+
max_index = (output_dim - 1) * stride + (stride - 1) + (kernel_dim - 1) * dilation
71+
= output_dim * stride - 1 + (kernel_dim - 1) * dilation
72+
```
73+
74+
Thus `input_dim = max_index + 1 = output_dim * stride + (kernel_dim - 1) * dilation`.
75+
76+
**Projection inference** handles `Affine` terms differently depending on `use_padding`:
5977

60-
The shape and projection inference handles `Conv_input` terms differently depending on `use_padding`. If `use_padding` is false, the impact of convolution kernels is incorporated additively during shape inference, and there's nothing more to do during projection inference. If `use_padding` is true, convolution kernels don't contribute during shape inference, and padding is computed during projection inference, keyed by `proj_id`. Padding is maximal size (width) of dilated kernels as encountered in `Dim` - `Conv_input` constraints and is propagated in either direction, although in practice for CNNs `Conv_input` should only appear as `subr` of `Dim_ineq`.
78+
- If `use_padding = false`: the kernel extent `dilation * (kernel_dim - 1)` is incorporated during shape inference; projection inference just applies the index formula with the actual `stride_offset`.
79+
- If `use_padding = true`: kernel extent doesn't affect shape inference; padding is computed during projection inference, keyed by `proj_id`. The padding is the maximal dilated kernel extent encountered in constraints involving that projection id.
6180

62-
Shape inference does not maintain padding for axes of individual tensor nodes, these padding values are computed and updated during projections inference.
81+
Shape inference does not maintain padding for axes of individual tensor nodesthese padding values are computed and updated during projections inference.
6382

6483
### Preventing Premature Guessing with Total_elems Constraints
6584

0 commit comments

Comments
 (0)