* Requires extending the expressivity of projections and the generalized einsum notation.
* Then, we can add convnet building blocks and corresponding examples, starting with MNIST.
* Verify or rethink the usefulness of dimension labels, and whether to introduce axis labels.
* 0.7: Replicate the scaffolding from [llm.c](https://github.com/karpathy/llm.c) for training GPT-2.
  * Useful building blocks for models in [lib/nn_blocks.ml](lib/nn_blocks.ml).
  * A language model example.
  * Port (translate or bind) the Python files from [llm.c](https://github.com/karpathy/llm.c) to implement tokenization, data loading and saving, etc.
  * At the end of 0.6.x, we should have an apples-to-apples benchmark comparing OCANNL to [llm.c](https://github.com/karpathy/llm.c) on both CPU and GPU.
* 0.8: Optimize performance -- low-hanging fruit.
  * First harvested from [Fast Multidimensional Matrix Multiplication on CPU from Scratch](https://siboehm.com/articles/22/Fast-MMM-on-CPU).
  * Then harvested from [How to Optimize a CUDA Matmul Kernel for cuBLAS-like Performance: a Worklog](https://siboehm.com/articles/22/CUDA-MMM).
  * Finally from [llm.c](https://github.com/karpathy/llm.c).
  * These will require splitting a routine into multiple CUDA kernels.
* 0.9: A new abstraction layer automating compilation/linking, execution, and some data transfers.
  * E.g. host-device transfers: copy from host if the host update is later than the previous device update (see the sketch after this list).
  * Concise syntax for transfers into the merge buffer, since we know which tensor node is transferred and where to.
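
As a rough illustration of the transfer automation planned for 0.9, here is a minimal OCaml sketch of the host-to-device staleness check; the `node_state` record and its fields are assumptions for illustration, not OCANNL's actual API:

```ocaml
(* A minimal sketch of the "copy from host if the host update is later than
   the previous device update" rule. The record and field names are assumed
   for illustration; updates are ordered by a monotonic counter. *)
type node_state = {
  host_updated_at : int;    (* last update of the host copy *)
  device_updated_at : int;  (* last update of the device copy *)
}

let needs_host_to_device_copy s = s.host_updated_at > s.device_updated_at
```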
Excerpts from [lib/shape_inference.md](lib/shape_inference.md):
### Non-tensor-like constraints
The above mechanisms (excluding `dim_constraint` and `row_constraint`) are sufficient to express tensor applications such as inner and outer products and axis permutations. Axis concatenation and "reverse concatenation" / splitting are handled by the representation above via the "product" `Prod` dimension constructor. The above mechanisms cannot directly express: size constraints, fixed position indexing (except for the special case of position 0), strides, convolutions. At present, we implement size constraints and fixed position indexing.
```ocaml
type dim_constraint = Unconstrained_dim | At_least_dim of int

type row_constraint =
  (* definition elided in this excerpt *)
```
During the solution process, the constraints are incorporated, or propagated, into the environment `constr` entry fields, and into further `constraint_` constraints, as needed. This provides sufficient scaffolding to implement the other complex constraints as the need arises.
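
To make the propagation concrete, here is a toy sketch of combining and checking `dim_constraint`s; the helpers `meet` and `satisfies` are assumed names for illustration, not OCANNL's internals (the type is redeclared so the sketch is self-contained):

```ocaml
(* A toy sketch of incorporating dim_constraints during solving. The helper
   names are assumptions for illustration, not OCANNL's actual internals. *)
type dim_constraint = Unconstrained_dim | At_least_dim of int

(* Combine two constraints on the same variable: the stronger bound wins. *)
let meet c1 c2 =
  match (c1, c2) with
  | Unconstrained_dim, c | c, Unconstrained_dim -> c
  | At_least_dim a, At_least_dim b -> At_least_dim (max a b)

(* Check a solved dimension against its accumulated constraint. *)
let satisfies dim = function
  | Unconstrained_dim -> true
  | At_least_dim n -> dim >= n
```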
### Product dimensions (Prod)
The `Prod` construct represents an axis that is a product of other axes. This can be used to model:
1. **Concatenation and splitting**: Multiple axes concatenated into a single axis, or a single axis split into multiple axes, as part of an operation.
2. **Multi-axis views**: Treating multiple axes as a single flattened axis, and vice versa.
For a `Prod [d1; d2; ...; dn]`:
* The dimension is the product of all constituent dimensions: `dim(d1) × dim(d2) × ... × dim(dn)`
* The projection respects the order of axes, implementing a row-major indexing scheme (see the sketch at the end of this section)
* During inference, constraints on the product propagate to constraints on the constituents
* In the einsum notation, product axes will be denoted using `&`, e.g., `i&j` represents a single axis that is the product of axes `i` and `j`
Product dimensions interact with other shape inference features:
* **Broadcasting**: A `Prod` dimension can be broadcast if its constituents are dimension-1
* **Constraints**: An `At_least_dim` constraint on a `Prod` propagates to its constituents
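
As a rough illustration of these semantics, the following sketch computes the size of a product axis and the row-major index mapping referenced above; the `dim` type here is a stand-in for illustration, not OCANNL's actual representation:

```ocaml
(* A minimal sketch of product-axis semantics, with an assumed stand-in
   `dim` type rather than OCANNL's actual representation. *)
type dim = Dim of int | Prod of dim list

(* The dimension of a Prod is the product of its constituents' dimensions. *)
let rec dim_size = function
  | Dim d -> d
  | Prod ds -> List.fold_left (fun acc d -> acc * dim_size d) 1 ds

(* Row-major flattening: per-constituent indices map to a single index into
   the product axis, respecting the order of the axes. *)
let prod_index ds indices =
  List.fold_left2 (fun acc d i -> (acc * dim_size d) + i) 0 ds indices

let () =
  (* An `i&j` axis with i of dimension 3 and j of dimension 4 has
     dimension 12; position (2, 1) flattens to 2*4 + 1 = 9. *)
  assert (dim_size (Prod [ Dim 3; Dim 4 ]) = 12);
  assert (prod_index [ Dim 3; Dim 4 ] [ 2; 1 ] = 9)
```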
## Solving the constraints
The constraints are solved by unification of the equation constraints, unification-like simplification of the inequality constraints, and propagation of the complex constraints. Since simplifying an inequality or propagating a constraint can generate more constraints, we need to be careful to keep the process terminating. The solution proceeds in stages.
* Stage 1 runs online as tensors are composed, and conservatively performs unification and constraint propagation. Stages 2, 3, 4 are only performed once necessary: when projections or dimensions are requested.
* Stage 2, when solving the constraints, substitutes dim variables in terminal shapes that do not have a LUB or other constraints by dimension-1 (see the defaulting sketch after this list). (This is generalized at stage 6 to all variables.) It substitutes row variables in terminal shapes that do not have a LUB by one axis if that's required to satisfy the variable's constraint.
* Stage 3, when solving the constraints, sets yet-unknown dimension and row variables in terminal shapes to their least upper bounds (if any). It substitutes row variables in terminal shapes that do not have a LUB by no-further-axes. (This is generalized at stage 6 to all variables.)
* Stage 4 sets yet-unknown dimensions whose lower bounds from direct accesses exceed 1 to their LUBs if they have any, otherwise to the lower bound.
* Stage 5 addresses `Total_elems` constraints with yet-unknown row variables. If the constraint can be satisfied by assuming the row variable is no-further-axes, it sets the row variable to `Broadcastable`, otherwise it sets it to one axis of the required dimension.
* Stage 6 sets row variables in the remaining inequalities to no-further-axes values. This can unlock further between-axis inequalities because of row variables sandwiched between leftmost axes from their side of the inequality and rightmost axes from the other side of the inequality.
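
The defaulting behavior of stages 2 and 3 can be illustrated with a toy sketch; the `dim_entry` type below is an assumption for illustration, not OCANNL's solver state:

```ocaml
(* A toy sketch of stage 2/3-style defaulting for leftover dim variables in
   terminal shapes; the types are assumed, not OCANNL's solver state. *)
type dim_entry =
  | Solved of int
  | Unsolved of { lub : int option }  (* least upper bound, if known *)

(* Closing a dim variable: prefer the LUB when known (stage 3), otherwise
   default to dimension-1 (stage 2, generalized at stage 6). *)
let close_dim = function
  | Solved d -> d
  | Unsolved { lub = Some d } -> d
  | Unsolved { lub = None } -> 1
```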
Other important functions in the `Shape` module:
* `finish_inference` is called right before some projections or array dimensions are required (typically because of jitting). It performs a second round of `propagate_shapes`, and then once again attempts to solve any remaining constraints that `propagate_shapes` didn't solve. Then it "closes the shapes": it substitutes out remaining shape variables by their LUBs if any, or by dimension-1 / `Broadcastable` (no-more-axes). Then it resets the environment state, since the shapes are now guaranteed to not have variables.
* `derive_projections` starts by freshening the `proj_id`s in the `update_step`. Then it generates and solves shape inequalities, and then generates and solves projection equations, and constructs the `projections` record.
* `of_spec` constructs a shape record from an einsum slot spec. If `deduced = Input_equals_output`, it adds the corresponding equation to the global environment.
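
To show how these functions fit together, here is a hedged sketch of the call order implied above; the three functions are unit stubs standing in for the `Shape`-module functions described in the text, with argument lists assumed for illustration:

```ocaml
(* A hedged sketch of the inference pipeline implied by this document. The
   stubs stand in for the Shape-module functions described in the text;
   their real signatures are not shown here. *)
let propagate_shapes _update_step = ()   (* stage 1: online propagation *)
let finish_inference () = ()             (* close shapes before jitting *)
let derive_projections _update_step = () (* build the projections record *)

let infer_and_project update_step =
  propagate_shapes update_step;
  finish_inference ();
  derive_projections update_step
```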
0 commit comments