lib/shape_inference.md
A tensor shape in OCANNL is composed of three rows of axes: batch, input and output.
A row is a sequence of axes of a single kind: batch, input, or output. The shape type incorporates information relevant to inference, in particular shape variables: both for individual axes (`dim` variables) and for extending a row with more axes (`row` variables). Currently, all rows are (independently) broadcastable: they can be broadcast to a larger number of axes. However, in OCANNL the broadcasting can happen "in the middle": not only the given trailing axes but also the given leading axes stay fixed, and the row grows in between.
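The "broadcasting in the middle" idea can be sketched with a toy model (this is an illustration only; the `row` record and `broadcast_middle` function below are hypothetical and not OCANNL's actual representation):

```ocaml
(* Toy model of middle-broadcasting: a row with known leading and trailing
   axes, and room "in the middle" that can absorb extra axes.
   Illustration only, not OCANNL's actual representation. *)
type row = { leading : int list; trailing : int list }

(* Broadcast row [r] against a concrete list of axis sizes [dims]:
   the leading and trailing axes must match (or be 1, as in the usual
   broadcasting rule), and any surplus axes land in the middle. *)
let broadcast_middle r dims =
  let nl = List.length r.leading and nt = List.length r.trailing in
  if List.length dims < nl + nt then None
  else
    let matches a b = a = b || a = 1 in
    let lead = List.filteri (fun i _ -> i < nl) dims in
    let trail = List.filteri (fun i _ -> i >= List.length dims - nt) dims in
    if List.for_all2 matches r.leading lead
       && List.for_all2 matches r.trailing trail
    then Some dims
    else None

let () =
  (* Row with fixed leading axis 2 and fixed trailing axis 5:
     the axes [3; 4] are absorbed "in the middle". *)
  match broadcast_middle { leading = [2]; trailing = [5] } [2; 3; 4; 5] with
  | Some dims -> assert (dims = [2; 3; 4; 5])
  | None -> assert false
```

Contrast this with NumPy-style broadcasting, where only the trailing axes are aligned and all growth happens at the front.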
```ocaml
type solved_dim = { d : int; label : string option; proj_id : proj_id option }
(** A single axis in a shape. *)
type dim =
| Var of dim_var
| Dim of solved_dim
| Conv_input of { stride : int; output : dim; dilation : int; kernel : dim }
  (** The offset is implicit, automatically derived. If [!use_padding] is [true], the offset is
      the left part of the dimensionality-preserving symmetric padding, otherwise it is 0. *)
type bcast =
| Row_var of row_var (** The row can be inferred to have more axes. *)
```
The projection inference functions.
### Convolutions
There is an important and intentional difference between `dims` in the `arrayjit` part of the project (tensor nodes, `Ndarray` buffers, code generation), which include padding in the dimension sizes, and shape types, shape inference and tensors, which exclude padding from the dimension sizes. There is a tension: once the delayed computations of padding, projections and dims (dimension sizes) are forced for a particular node, the padding can no longer be updated (the underlying `Ndarray` buffer might already be created). Since during inference we update the padding incrementally, without variables standing in for insufficient information, the distinction between during-inference and post-inference padding of a tensor node unfortunately becomes observable.
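The `Conv_input` relation can be made concrete with a small calculation. The sketch below assumes the relation `input = stride * output + kernel_terms` with `kernel_terms = dilation * (kernel - 1)`, which matches standard dilated-convolution arithmetic; the function names and the exact formula are illustrative assumptions, not OCANNL's API:

```ocaml
(* Illustrative arithmetic for Conv_input; names and the exact formula
   are assumptions, not OCANNL's actual definitions. We read the relation
   as: input = stride * output + dilation * (kernel - 1). *)
let input_extent ~stride ~output ~dilation ~kernel =
  stride * output + dilation * (kernel - 1)

(* If [use_padding] is set, the implicit offset is the left half of the
   symmetric, dimensionality-preserving padding; otherwise it is 0. *)
let implicit_offset ~use_padding ~dilation ~kernel =
  if use_padding then dilation * (kernel - 1) / 2 else 0

let () =
  (* stride 1, kernel 3, dilation 1: the input spans two cells more than
     stride * output, and with padding the left offset is 1. *)
  assert (input_extent ~stride:1 ~output:8 ~dilation:1 ~kernel:3 = 10);
  assert (implicit_offset ~use_padding:true ~dilation:1 ~kernel:3 = 1);
  assert (implicit_offset ~use_padding:false ~dilation:1 ~kernel:3 = 0)
```

Because the offset is derived rather than stored, forcing the padding of a node (as described above) is exactly the point where these quantities become fixed.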