`lib/shape_inference.md` (6 lines changed: 6 additions and 0 deletions)

@@ -65,6 +65,8 @@ When `use_padding` is true, the offset is chosen to preserve dimensionality (i.e
Shape and projection inference handle `Conv_input` terms differently depending on `use_padding`. If `use_padding` is false, the impact of convolution kernels is incorporated additively during shape inference, and there is nothing more to do during projection inference. If `use_padding` is true, convolution kernels do not contribute during shape inference, and padding is computed during projection inference, keyed by `proj_id`. The padding is the maximal size (width) of dilated kernels as encountered in constraints between a `Dim` and a `Conv_input`, and it is propagated in either direction, although in practice for CNNs `Conv_input` should only appear as the `subr` of a `Dim_ineq`.

Shape inference does not maintain padding for the axes of individual tensor nodes; these padding values are computed and updated during projection inference.
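
To make the padding rule concrete, here is a minimal sketch in OCaml; the helpers `dilated_width` and `update_padding` and the global `paddings` table are illustrative assumptions, not the project's actual API.

```ocaml
(* Minimal sketch, not the project's actual API: accumulate the maximal
   dilated-kernel width per proj_id. A kernel of size k with dilation d
   spans d * (k - 1) + 1 cells. *)
let dilated_width ~kernel_size ~dilation = dilation * (kernel_size - 1) + 1

(* Hypothetical table from projection ids to the padding inferred so far. *)
let paddings : (int, int) Hashtbl.t = Hashtbl.create 16

(* Processing a [Dim] vs. [Conv_input] constraint can only grow the padding:
   keep the maximum over all dilated kernels seen for this projection. *)
let update_padding ~proj_id ~kernel_size ~dilation =
  let w = dilated_width ~kernel_size ~dilation in
  let current = Option.value ~default:0 (Hashtbl.find_opt paddings proj_id) in
  Hashtbl.replace paddings proj_id (max current w)
```
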
### Inference strategy
The actual shape inference combines row polymorphism with (nominal) subtyping, as known from the type inference literature. The subtyping stems merely from the fact that a dimension-1 axis can be used in the context of any dimension, due to per-axis broadcasting. Row polymorphism stems from broadcasting to more axes: for example, when unifying an unknown (shape) row with a known one, we cannot assume that the unknown row will have just the axes of the known one, because the known row may be meant to be broadcast here to more axes. The combination of row polymorphism with nominal subtyping means that the constraints we are solving are inequalities, both inequalities between rows (the `Row.t` type, i.e. the `row` type above) and between axes/dimensions (the `Row.dim` type). We maintain the inequality ordering between variables in the environment to compute the transitive closure during simplification. We also maintain a least upper bound on the solution.
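
To sketch what these constraints look like, here is a deliberately simplified picture; the real `Row.dim` and `Row.t` carry more information, and the constructor fields below are assumptions for illustration only.

```ocaml
(* Illustrative only: the project's Row.dim and Row.t are richer than this. *)
type dim = Dim of int | Dim_var of int  (* a known size, or a dimension variable *)

type row = { dims : dim list; row_var : int option }
(* A trailing row variable means the row may broadcast to additional axes. *)

type ineq =
  | Dim_ineq of { subr : dim; super : dim }  (* subr can be used where super is expected *)
  | Row_ineq of { subr : row; super : row }

(* Per-axis subtyping for fully known dimensions: a dimension-1 axis
   broadcasts, so it is below every dimension; otherwise sizes must agree. *)
let dim_leq ~subr ~super =
  match subr, super with
  | Dim 1, _ -> true
  | Dim a, Dim b -> a = b
  | _ -> false  (* constraints involving variables are kept for the solver *)
```
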
@@ -210,6 +212,10 @@ The projection inference functions.
* `solve_proj_equations` unifies the projection equations, using union-find to maintain a representative for equal projections. Projections that already have an `axis_index` are `non_product` (not to be iterated over). The remaining projections have a `product_dim` and get a fresh iterator.
* `get_dim_index` gets an `axis_index` for a `dim` based on the representative of its `proj_id`, and returns `Fixed_idx 0` for dim = 1 (see the sketch after this list).
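
A loose sketch of the union-find step and the index lookup follows; the helper names, the `iterators` table, and the polymorphic-variant index type are assumptions, not the project's actual signatures or its `Fixed_idx` constructor.

```ocaml
(* Illustrative union-find over projection ids; the actual solver also records
   pre-assigned axis_index values and product_dim sizes per representative. *)
let parent : (int, int) Hashtbl.t = Hashtbl.create 16

let rec find p =
  match Hashtbl.find_opt parent p with
  | None -> p                         (* p is its own representative *)
  | Some q ->
    let r = find q in
    Hashtbl.replace parent p r;       (* path compression *)
    r

let union p q =
  let rp, rq = find p, find q in
  if rp <> rq then Hashtbl.replace parent rp rq

(* Hypothetical lookup mirroring get_dim_index: a dimension-1 axis is fixed
   at offset 0, any other dim uses the iterator of its representative. *)
let get_index ~iterators ~proj_id ~dim =
  if dim = 1 then `Fixed_idx 0
  else `Iterator (Hashtbl.find iterators (find proj_id))
```
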
### Convolutions
There is an important and intentional disconnect regarding `dims` between the two sides of the project. On the `arrayjit` side (tensor nodes, `Ndarray` buffers, code generation), dimension sizes include padding; on the other hand, shape types, shape inference and tensors exclude padding from the dimension sizes. There is a tension: once the delayed computations of padding, projections and dims (dimension sizes) are forced for a particular node, the padding can no longer be updated (the underlying `Ndarray` buffer might already have been created). Since during inference we update the padding incrementally, without variables standing in for insufficient information, this unfortunately makes the distinction between during-inference and post-inference padding of a tensor node observable.
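
As a small illustration of that disconnect, the buffer-side size of an axis adds the padding on both ends to the inferred size; the `axis_padding` type and its field names here are assumptions for this sketch, not the project's actual representation.

```ocaml
(* Illustrative relation between the two views of an axis: shape inference
   works with the unpadded size, while arrayjit-side dims include the padding. *)
type axis_padding = { left : int; right : int }  (* hypothetical field names *)

let padded_dim ~unpadded ~(padding : axis_padding) =
  unpadded + padding.left + padding.right

(* E.g. an axis inferred as 28 with padding 2 on each side occupies 32 cells
   in the underlying Ndarray buffer. *)
let () = assert (padded_dim ~unpadded:28 ~padding:{ left = 2; right = 2 } = 32)
```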