
Commit 07092c0

Some more progress on concatenation-along-axes before we give up
1 parent 4dabf5b commit 07092c0

4 files changed: +31 −39 lines changed

README.md

Lines changed: 7 additions & 7 deletions
@@ -66,24 +66,24 @@ NOTE: debug logging from CUDA in complex settings is a bit tricky, it involves a
 
 This is very tentative.
 
-* 0.6: Replicate the scaffolding from [llm.c](https://github.com/karpathy/llm.c) for training GPT-2.
+* 0.6: Hopefully-efficient expressivity: concatenation and splitting, convolution, maybe block tensors.
+  * Requires extending expressivity of projections and the generalized einsum notation.
+  * Then, we can add convnet building blocks and corresponding examples starting with MNIST.
+  * Verify or rethink usefulness of dimension labels, and whether to introduce axis labels.
+* 0.7: Replicate the scaffolding from [llm.c](https://github.com/karpathy/llm.c) for training GPT-2.
   * Useful building blocks for models in [lib/nn_blocks.ml](lib/nn_blocks.ml).
   * A language model example.
   * Port (translate or bind) the Python files from [llm.c](https://github.com/karpathy/llm.c) to implement tokenization, data loading and saving etc.
   * At the end of 0.6.x, we should have an apples-to-apples benchmark comparing OCANNL to [llm.c](https://github.com/karpathy/llm.c) for both CPU and GPU.
-* 0.7: Optimize performance -- low hanging fruit.
+* 0.8: Optimize performance -- low hanging fruit.
   * First harvested from [Fast Multidimensional Matrix Multiplication on CPU from Scratch](https://siboehm.com/articles/22/Fast-MMM-on-CPU).
   * Then harvested from [How to Optimize a CUDA Matmul Kernel for cuBLAS-like Performance: a Worklog](https://siboehm.com/articles/22/CUDA-MMM).
   * Finally from [llm.c](https://github.com/karpathy/llm.c).
   * These will require splitting a routine into multiple CUDA kernels.
-* 0.8: A new abstraction layer automating compilation/linking, execution, and some data transfers.
+* 0.9: A new abstraction layer automating compilation/linking, execution, and some data transfers.
   * E.g. host-device transfers: copy from host if host update is later than the previous device update.
   * Concise syntax for transfers into the merge buffer since we know which tensor node is transferred and where to.
   * At the end of 0.8.x, OCANNL has a REPL.
-* 0.9: Hopefully-efficient expressivity: block tensors, convolution.
-  * Requires extending expressivity of projections and the generalized einsum notation.
-  * Then, we can add convnet building blocks and corresponding examples starting with MNIST.
-  * Verify or rethink usefulness of dimension labels, and whether to introduce axis labels.
 * 0.10: Optimize performance: program search.
   * Instead of dynamic scheduling as in tinygrad, we can schedule statically by program search.
   * We should also reproduce the search that tinygrad is doing.

lib/row.ml

Lines changed: 2 additions & 7 deletions
@@ -875,12 +875,7 @@ let%debug5_sexp solve_dim_ineq ~(stage : stage) ~(cur : dim) ~(subr : dim) (env
       @@ Shape_error
            ( "Cannot compare Prod with unresolved variables in inequality",
              [ Dim_mismatch [ cur; subr ] ] )
-  | Prod _, Var _ | Var _, Prod _ ->
-      (* Similar to above - we need all dimensions resolved to compare *)
-      raise
-      @@ Shape_error
-           ("Cannot compare Prod with variables in inequality", [ Dim_mismatch [ cur; subr ] ])
-  | Var cur_v, Var subr_v -> (
+  | Var cur_v, Var subr_v -> (
       match (Map.find env.dim_env cur_v, Map.find env.dim_env subr_v) with
       | Some (Bounds_dim { cur = cur1; _ }), _ when List.mem ~equal:equal_dim_var cur1 subr_v ->
          ([ Dim_eq { d1 = cur; d2 = subr } ], env)
@@ -1055,7 +1050,7 @@ let%debug5_sexp solve_dim_ineq ~(stage : stage) ~(cur : dim) ~(subr : dim) (env
            Map.set env.dim_env ~key:subr_v
              ~data:(Bounds_dim { lub = Some cur; cur = cur2; subr = subr2; constr = constr2 });
        } ))
-  | Var _, Dim _ (* when d2 > 1 *) -> ([ Dim_eq { d1 = cur; d2 = subr } ], env)
+  | Var _, (Dim _ (* when d2 > 1 *) | Prod _) -> ([ Dim_eq { d1 = cur; d2 = subr } ], env)
   | Dim _, Dim _ ->
      raise
      @@ Shape_error ("dimension comparison for axis: mismatch", [ Dim_mismatch [ cur; subr ] ])
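For intuition, here is a minimal self-contained sketch, using toy stand-in types rather than OCANNL's actual `Row` definitions, of what the changed arm above does: when the supertype side (`cur`) is still a variable and the subtype side (`subr`) is already a concrete `Dim` or `Prod`, the solver now emits an equality constraint instead of raising a shape error.

```ocaml
(* Toy illustration only; OCANNL's real dim and constraint types are richer. *)
type dim =
  | Dim of int       (* a concrete axis size *)
  | Var of string    (* an inference variable *)
  | Prod of dim list (* an axis that is a product of other axes *)

type constraint_result = Dim_eq of { d1 : dim; d2 : dim }

(* cur is the supertype side, subr the subtype side, as in solve_dim_ineq. *)
let solve_var_vs_concrete ~(cur : dim) ~(subr : dim) : constraint_result list =
  match (cur, subr) with
  | Var _, (Dim _ | Prod _) ->
      (* The variable must match the concrete dimension: emit an equation. *)
      [ Dim_eq { d1 = cur; d2 = subr } ]
  | _ -> []

let () =
  match solve_var_vs_concrete ~cur:(Var "i") ~subr:(Prod [ Dim 2; Dim 3 ]) with
  | [ Dim_eq _ ] -> print_endline "Var vs. Prod resolved to an equation"
  | _ -> print_endline "no constraint emitted"
```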

lib/row.mli

Lines changed: 0 additions & 1 deletion
@@ -28,7 +28,6 @@ val get_dim : d:int -> ?label:string -> unit -> dim
 val dim_to_int_exn : dim -> int
 val dim_to_string : [> `Only_labels ] -> dim -> string
 
-
 (** Extracts all dimension variables from a dim, including from nested products. *)
 val dim_vars : dim -> dim_var list

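The doc comment above says `dim_vars` extracts variables "including from nested products", which suggests a recursive traversal. A minimal sketch of that behaviour over toy stand-in types (not the actual signatures from `lib/row.mli`) might look like:

```ocaml
(* Toy types standing in for Row.dim and Row.dim_var. *)
type dim_var = string
type dim = Dim of int | Var of dim_var | Prod of dim list

(* Collect every variable, descending into nested Prod constructors. *)
let rec dim_vars : dim -> dim_var list = function
  | Dim _ -> []
  | Var v -> [ v ]
  | Prod ds -> List.concat_map dim_vars ds

let () =
  dim_vars (Prod [ Var "i"; Prod [ Dim 2; Var "j" ] ])
  |> String.concat ", "
  |> print_endline (* prints: i, j *)
```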
lib/shape_inference.md

Lines changed: 22 additions & 24 deletions
@@ -112,7 +112,7 @@ type logic =
 
 ### Non-tensor-like constraints
 
-The above mechanisms (excluding `dim_constraint` and `row_constraint`) are sufficient to express tensor applications such as inner and outer products, axis permutations. They cannot directly express: size constraints, fixed position indexing (except for the special case of position 0), axis concatenation and "reverse concatenation" / splitting, strides, convolutions. At present, we implement size constraints and fixed position indexing.
+The above mechanisms (excluding `dim_constraint` and `row_constraint`) are sufficient to express tensor applications such as inner and outer products, axis permutations. Axis concatenation and "reverse concatenation" / splitting is handled by the representation above via the "product" `Prod` dimension constructor. The above mechanisms cannot directly express: size constraints, fixed position indexing (except for the special case of position 0), strides, convolutions. At present, we implement size constraints and fixed position indexing.
 
 ```ocaml
 type dim_constraint = Unconstrained_dim | At_least_dim of int
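As a hedged illustration of the `dim_constraint` variant quoted in the hunk above, checking a resolved axis size against such a constraint could look like the following toy function; the `satisfies` helper is hypothetical, since OCANNL propagates these constraints through the inference environment as described below rather than checking them directly.

```ocaml
(* Toy check of the constraint variant quoted above against an already
   resolved axis size. *)
type dim_constraint = Unconstrained_dim | At_least_dim of int

let satisfies (c : dim_constraint) ~(size : int) : bool =
  match c with
  | Unconstrained_dim -> true
  | At_least_dim d -> size >= d

let () =
  assert (satisfies (At_least_dim 3) ~size:4);
  assert (not (satisfies (At_least_dim 3) ~size:2));
  print_endline "dim_constraint sketch OK"
```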
@@ -125,13 +125,33 @@ type row_constraint =
 
 During the solution process, the constraints are incorporated, or propagated, into the environment `constr` entry fields, and into further `constraint_` constraints, as needed. This provides sufficient scaffolding to implement the other complex constraints as the need arises.
 
+### Product dimensions (Prod)
+
+The `Prod` construct represents an axis that is a product of other axes. This can be used to model:
+
+1. **Concatenation and splitting**: Multiple axes concatenated into a single axis or a single axis split into multiple as part of an operation.
+2. **Multi-axis views**: Treating multiple axes as a single flattened axis and vice-versa.
+
+For a `Prod [d1; d2; ...; dn]`:
+
+* The dimension is the product of all constituent dimensions: `dim(d1) × dim(d2) × ... × dim(dn)`
+* The projection respects the order of axes, implementing a row-major indexing scheme
+* During inference, constraints on the product propagate to constraints on the constituents
+* In the einsum notation, product axes will be denoted using `&`, e.g., `i&j` represents a single axis that is the product of axes `i` and `j`
+
+Product dimensions interact with other shape inference features:
+
+* **Broadcasting**: A Prod dimension can be broadcasted if its constituents are dimension-1
+* **Inequalities**: `Prod ds1 ≥ Prod ds2` requires compatible structures and element-wise inequalities
+* **Constraints**: An `At_least_dim` constraint on a Prod propagates to its constituents
+
 ## Solving the constraints
 
 The constraints are solved by: unification of the equation constraints, unification-like simplification of the inequality constraints, propagation of the complex constraints. Simplification of an inequality, and constraint propagation, can generate more constraints, so we need to be careful to keep it terminating. The solution proceeds in stages.
 
 * Stage 1 is online as tensors are composed, and conservatively performs unification and constraint propagation. Stages 2, 3, 4 are only performed once necessary: when projections or dimensions are requested.
 * Stage 2, when solving the constraints, substitutes dim variables in terminal shapes that do not have a LUB or other constraints, by dimension-1. (This is generalized at stage 6 to all variables.) It substitutes row variables in terminal shapes that do not have a LUB by one axis if that's required to satisfy the variable's constraint.
-* Stage 3, when solving the constraints, sets yet-unknown dimension and row variables in terminal shapes to their least upper bounds (if any). It substitutes row variables in terminal shapes that do not have a LUB by no-further-axes. (This is generalized at stage 5 to all variables.)
+* Stage 3, when solving the constraints, sets yet-unknown dimension and row variables in terminal shapes to their least upper bounds (if any). It substitutes row variables in terminal shapes that do not have a LUB by no-further-axes. (This is generalized at stage 6 to all variables.)
 * Stage 4 sets yet-unknown dimensions with >1 lower bounds from direct accesses, to their LUBs if they have any, otherwise to the lower bound.
 * Stage 5 addresses `Total_elems` constraints with yet-unknown row variables. If the constraint can be satisfied by assuming the row variable is no-further-axes, it sets the row variable to `Broadcastable`, otherwise it sets it to one axis of the required dimension.
 * Stage 6 sets row variables in the remaining inequalities to no-further-axes values. This can unlock further between-axis inequalities because of row variables sandwiched between leftmost axes from their side of the inequality and rightmost axes from the other side of the inequality.
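The product-of-constituents size and the row-major projection described in the added section can be made concrete with a small self-contained sketch; the types and the `total_dim`/`unflatten` helpers below are illustrative stand-ins, not OCANNL's implementation.

```ocaml
(* Toy model of a Prod axis: its size is the product of the constituent
   sizes, and a flat position on the axis maps row-major onto the
   constituents (last constituent varies fastest). *)
type dim = Dim of int | Prod of dim list

let rec total_dim = function
  | Dim d -> d
  | Prod ds -> List.fold_left (fun acc d -> acc * total_dim d) 1 ds

(* Split a flat position on a Prod with the given constituent sizes into
   per-constituent positions, row-major. *)
let unflatten (sizes : int list) (pos : int) : int list =
  List.rev sizes
  |> List.fold_left (fun (pos, acc) d -> (pos / d, (pos mod d) :: acc)) (pos, [])
  |> snd

let () =
  assert (total_dim (Prod [ Dim 2; Dim 3 ]) = 6);
  (* Position 4 on an axis written i&j, with i of size 2 and j of size 3,
     corresponds to i = 1, j = 1. *)
  assert (unflatten [ 2; 3 ] 4 = [ 1; 1 ]);
  print_endline "Prod sketch OK"
```

In OCANNL itself the analogous index mapping would come out of the projections constructed by `derive_projections` (see the hunk below), rather than being computed ad hoc like this.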
@@ -187,25 +207,3 @@ Other important functions in the `Shape` module.
 * `finish_inference` is called right before some projections or array dimensions are required (typically, because of jitting). It performs a second round of `propagate_shapes`, and then once again attempts to solve any remaining constraints that `propagate_shapes` didn't solve. Then it "closes the shapes": substitutes out remaining shape variables by their LUBs if any, or dimension-1 / `Broadcastable` (no-more-axes). Then it resets the environment state, since the shapes are now guaranteed to not have variables.
 * `derive_projections` starts by freshening the `proj_id`s in the `update_step`. Then it generates and solves shape inequalities, and then generates and solves projection equations, and constructs the `projections` record.
 * `of_spec` constructs a shape record from an einsum slot spec. If `deduced = Input_equals_output`, it adds the corresponding equation to the global environment.
-
-### Product dimensions (Prod)
-
-The `Prod` construct represents an axis that is a product of other axes. This can be used to model:
-
-1. **Concatenation**: Multiple axes concatenated into a single axis
-2. **Multi-axis views**: Treating multiple axes as a single flattened axis
-
-For a `Prod [d1; d2; ...; dn]`:
-
-* The dimension is the product of all constituent dimensions: `dim(d1) × dim(d2) × ... × dim(dn)`
-* The projection respects the order of axes, implementing a row-major indexing scheme
-* During inference, constraints on the product propagate to constraints on the constituents
-* In the einsum notation, product axes will be denoted using `&`, e.g., `i&j` represents a single axis that is the product of axes `i` and `j`
-
-Product dimensions interact with other shape inference features:
-
-* **Broadcasting**: A Prod dimension can be broadcasted if all its constituents are dimension-1
-* **Inequalities**: `Prod ds1 ≥ Prod ds2` requires compatible structures and element-wise inequalities
-* **Constraints**: An `At_least_dim` constraint on a Prod propagates to its constituents
-
-The actual shape inference combines row polymorphism with (nominal) subtyping, as known in the type inference literature.
