You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: lib/shape_inference.md
+2-3Lines changed: 2 additions & 3 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -126,8 +126,7 @@ type logic =
126
126
(** Permutes the axes of a shape. One case of [Transpose] is to swap inputs with outputs of [s1],
127
127
hence the name. *)
128
128
| Terminal of [ `Data of Ir.Assignments.init_data | `Fetch of Ir.Assignments.fetch_op ]
129
-
(** Extracts any available shape information from the initialization. E.g.
130
-
for [File_mapped fn], opens the file [fn] to check its length. *)
129
+
(** Extracts any available shape information from the initialization, e.g. the number of elements. *)
131
130
```
132
131
133
132
### Non-tensor-like constraints
@@ -208,7 +207,7 @@ There is an important and intentional difference between `dims` in the `arrayjit
208
207
Other important functions in the `Shape` module.
209
208
210
209
*`einsum_slot_spec_to_dims_bio ~generative` parses an einsum spec for a single shape, returns the three rows and a mapping from axis (`dim`) variables to indices where the einsum specifies fixed indexing. When `generative` is true for the kind of a row, when an axis has a fixed projection to dimension 0, the axis is not a variable added to the fixed indexing mapping, but is instead dimension-1 (solved). The "generative" rows are the ones with no initial user-provided shape information. This is just a heuristic to avoid surprises where a tensor axis with only dimension 0 populated gets inferred a bigger dimension size -- it might be revisited in the future.
211
-
*`get_inequalities` builds row inequalities by pairing the rows of the current shape (as `cur`) with the rows of sub-shapes (as `subr`). It also derives a batch row constraint for terminals initialized with `Constant_fill values` and `File_mapped (filename, prec)` (where the file is scanned to get its length). For `Batch_slice` (the `@|` operation) it waits till the batch row variables (if any) are solved, and derives row equations (not inequalities) between the current shape and the sub-shape, with `cur_sh.batch.dims` expanded to account for the slicing / indexing. For einsum specs, it derives inequalities, roughly: _current shape ≥ lhs spec shape_, and _rhs spec shape ≥ sub-shape_.
210
+
*`get_inequalities` builds row inequalities by pairing the rows of the current shape (as `cur`) with the rows of sub-shapes (as `subr`). It also derives a batch row constraint for terminals initialized with `Constant_fill values`. For `Batch_slice` (the `@|` operation) it waits till the batch row variables (if any) are solved, and derives row equations (not inequalities) between the current shape and the sub-shape, with `cur_sh.batch.dims` expanded to account for the slicing / indexing. For einsum specs, it derives inequalities, roughly: _current shape ≥ lhs spec shape_, and _rhs spec shape ≥ sub-shape_.
212
211
*`propagate_shapes` gets and then solves the inequalities, using a global state for the environment. It udpates the shapes in-place with the partial solution. It is invoked twice for each `update_step`: first during the bottom-up process of building tensors, and then in reverse order from `finish_inference`.
213
212
*`finish_inference` is called right before some projections or array dimensions are required (typically, because of jitting). It performs a second round of `propagate_shapes`, and then once again attempts to solve any remaining constraints that `propagate_shapes` didn't solve. Then it "closes the shapes": substitutes out remaining shape variables by their LUBs if any, or dimension-1 / `Broadcastable` (no-more-axes). Then it resets the environment state, since the shapes are now guaranteed to not have variables.
214
213
*`derive_projections` starts by freshening the `proj_id`s in the `update_step`. Then it generates and solves shape inequalities, and then generates and solves projection equations, and constructs the `projections` record.
0 commit comments