To separate concerns, OCANNL is split into the `arrayjit` library, responsible for compiling high-level n-D array operation sequences (`Assignments.comp`) via backends such as sync_cc, metal, and cuda, and the main `ocannl` library, responsible for deriving the operations that compute forward propagation and backpropagation from tensor expressions. In particular, `arrayjit` contains `Indexing`, which represents complex indexing into arrays, and the main `ocannl` library has the `Row` and `Shape` modules, which do most of the heavy lifting in the translation from concise tensor expressions to sequences of assignments.
In OCANNL, shape inference broadly consists of inferring the `Shape.t` record -- shape inference proper -- and inferring the `Indexing.projections` record -- projections inference. `Shape.t` records are mutable, so that partially inferred shapes can be observed by the user. Shape and projections inference is intended to be declarative -- independent of the order in which constraints are added. There is one aspect that is not declarative: when tensor expressions are compiled to assignments, i.e. jitted, still-unsolved shape variables in terminal nodes are substituted by their least upper bounds if any, or else by dimension-1 (no label) / no-more-axes.
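
This non-declarative closing step can be pictured with a small, self-contained sketch. Everything here (the `dim` type, `close_dim`, the `lub` parameter) is hypothetical and only illustrates the substitution rule from the paragraph above, not OCANNL's actual API:

```ocaml
(* Minimal sketch of "closing" a still-unsolved dimension variable: fall
   back to the least upper bound when one was inferred, and to a
   broadcastable dimension-1 otherwise. *)
type dim = Dim of int | Var of string

let close_dim ~(lub : dim option) (d : dim) : dim =
  match d with
  | Dim _ -> d (* already solved: keep it *)
  | Var _ -> (
      match lub with
      | Some bound -> bound (* substitute the least upper bound, if any *)
      | None -> Dim 1 (* otherwise default to dimension-1 *))
```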
The bulk of the projections inference happens alongside shape inference, with the projections-relevant information stored in auxiliary fields -- this prevents subtle bugs where projection semantics deviates from shape semantics, and will simplify adding new shape/projection inference features. Shape inference happens during `propagate_shapes` calls, and then again in a `finish_inference` call, which is triggered whenever the dimensions or projections are required (i.e. typically by jitting). Finally, the projections are reconstructed in `derive_projections`. It would seem `derive_projections` could reuse the already-computed constraint solutions. But we face a problem: we must prevent contaminating projections across different operations. To illustrate: we may conclude that the dimensions of two axes are the same because they are reduced together in another operation -- this should not force the axes to share a projection in the operation being processed. To prevent the contamination, in each `derive_projections` call we freshen the projection ids in the (inferred) shapes, and regenerate and re-solve the constraints with the fresh projection ids.
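
The freshening idea can be sketched as follows; the `fresh_id` counter and `freshen_ids` helper are invented for illustration (OCANNL's actual ids and code differ). The memo table preserves sharing of projection ids within the one operation being processed while severing any sharing with other operations:

```ocaml
(* Hypothetical sketch: replace each projection id (here, an int) with a
   fresh one. Ids equal within the input stay equal in the output, so
   intra-operation sharing survives, but cross-operation sharing is cut. *)
let fresh_id =
  let counter = ref 0 in
  fun () ->
    incr counter;
    !counter

let freshen_ids (ids : int list) : int list =
  let memo = Hashtbl.create 16 in
  List.map
    (fun id ->
      match Hashtbl.find_opt memo id with
      | Some fresh -> fresh
      | None ->
          let fresh = fresh_id () in
          Hashtbl.add memo id fresh;
          fresh)
    ids
```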
### Preventing Premature Guessing with Total_elems Constraints
A critical aspect of shape inference is avoiding premature "guessing" of dimension variables to minimal values (dimension-1 with no label, or no-further-axes for rows) when such guessing would make pending constraints unsatisfiable. This is particularly important for `Total_elems` constraints.
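
As a hedged, self-contained sketch of the idea (the `dim` type and all names are invented here, not OCANNL's actual `Total_elems` representation): a constraint fixing the product of a row's dimensions to a known total can be solved exactly once a single variable remains, and should otherwise be deferred rather than discharged by guessing variables to dimension-1:

```ocaml
type dim = Dim of int | Var of string

type outcome =
  | Satisfied
  | Unsatisfiable
  | Solved of string * int (* variable name and its inferred dimension *)
  | Deferred (* several unknowns remain: wait instead of guessing *)

let try_total_elems ~(total : int) (dims : dim list) : outcome =
  (* Product of the already-known dimensions. *)
  let known =
    List.fold_left (fun acc -> function Dim d -> acc * d | Var _ -> acc) 1 dims
  in
  match List.filter_map (function Var v -> Some v | Dim _ -> None) dims with
  | [] -> if known = total then Satisfied else Unsatisfiable
  | [ v ] when total mod known = 0 -> Solved (v, total / known)
  | [ _ ] -> Unsatisfiable
  | _ -> Deferred
```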
### Inference strategy
The actual shape inference combines row polymorphism with (nominal) subtyping, as known in the type inference literature. The subtyping stems merely from the fact that a dimension-1 axis with no label can be used in the context of any dimension, due to per-axis broadcasting. Row polymorphism stems from broadcasting to more axes: for example, when unifying an unknown (shape) row with a known one, we cannot assume that the unknown row will have just the axes of the known one, because the known row may be meant to be broadcast here to more axes. The combination of row polymorphism with nominal subtyping means that the constraints we are solving are inequalities, both between rows (the `Row.t` type, i.e. the `row` type above) and between axes/dimensions (the `Row.dim` type). We maintain the inequality ordering between variables in the environment so as to compute its transitive closure during simplification. We also maintain a least upper bound on the solution.
```ocaml
type dim_entry =
  (* Sketch of the environment entry for a dimension variable: either the
     variable is solved, or we record its inequality bounds and least
     upper bound. Constructor and field names are illustrative. *)
  | Solved_dim of dim
  | Bounds_dim of { cur : dim_var list; subr : dim_var list; lub : dim option }
```
## Solving the constraints
The constraints are solved by: unification of the equation constraints, unification-like simplification of the inequality constraints, and propagation of the complex constraints. The inequalities are like those in type systems combining parametric polymorphism with structural and nominal subtyping, where the nominal subtyping relation states that a dimension-1 axis without a label is smaller than all axes, while axes with mismatching dimensions or mismatching labels are incomparable. For rows, the subtyping is suffix-wise (shorter is smaller) and axis-wise.
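
A hedged sketch of the per-axis ordering (the `axis` record and the `subtype` function are invented for illustration, not OCANNL's actual definitions):

```ocaml
(* [subtype a b] asks whether axis [a] can be used where [b] is expected. *)
type axis = { dim : int; label : string option }

let subtype a b =
  (* A dimension-1 axis without a label is below every axis... *)
  (a.dim = 1 && a.label = None)
  (* ...otherwise axes are only comparable when dimension and label agree. *)
  || (a.dim = b.dim && a.label = b.label)
```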
Simplification of an inequality, and constraint propagation, can generate more constraints, so we need to be careful to keep the process terminating. The solution proceeds in stages; currently there are 8 stages, with a fractional stage coming from splitting an earlier design.