
Commit 75e8fbc

Fixes #410; in progress: refine shape inference to treat dim-1 with a label the same as dim>1; only dim-1 without a label is treated differently (more general)
1 parent c60cff6 · commit 75e8fbc

File tree

4 files changed: +98, −85 lines

docs/shape_inference.md

Lines changed: 4 additions & 4 deletions
````diff
@@ -2,7 +2,7 @@
 
 To separate concerns, OCANNL is split into the `arrayjit` library, responsible for compilation of high-level n-D array operation sequences (`Assignments.comp`) via backends such as sync_cc, metal and cuda, and the main `ocannl` library, responsible for deriving the operations computing the forward propagation and backpropagation from tensor expressions. In particular, `arrayjit` contains `Indexing`, which represents complex indexing into arrays, and the main library `ocannl` has `Row` and `Shape` modules, which do the most "heavy-lifting" in the translation from concise tensor expressions to sequences of assignments.
 
-Shape inference broadly speaking consists in OCANNL of inferring the `Shape.t` record -- shape inference proper, and inferring the `Indexing.projections` record -- projections inference. `Shape.t` records are mutable, so that the partially inferred shapes can be observed by the user. Shape and projections inference is intended to be declarative -- independent of the order in which constraints are added. There is one aspect that is not declarative: when tensor expressions are compiled to assignments, i.e. jitted, still-unsolved shape variables in terminal nodes are substituted by their least upper bounds if any, or by dimension-1 / no-more-axes.
+Shape inference broadly speaking consists in OCANNL of inferring the `Shape.t` record -- shape inference proper, and inferring the `Indexing.projections` record -- projections inference. `Shape.t` records are mutable, so that the partially inferred shapes can be observed by the user. Shape and projections inference is intended to be declarative -- independent of the order in which constraints are added. There is one aspect that is not declarative: when tensor expressions are compiled to assignments, i.e. jitted, still-unsolved shape variables in terminal nodes are substituted by their least upper bounds if any, or by dimension-1 (no label) / no-more-axes.
 
 The bulk of the projections inference happens alongside shape inference, with the projections-relevant information stored in auxiliary fields -- this prevents subtle bugs where projection semantics deviates from shape semantics, and will simplify adding new shape/projection inference features. Shape inference happens during `propagate_shapes` calls, and then again in a `finish_inference` call, which is triggered whenever the dimensions or projections are required (i.e. typically by jitting). Finally, the projections are reconstructed in `derive_projections`. It would seem `derive_projections` could reuse the already-computed solutions constraints. But we face a problem: we must prevent contaminating projections across different operations. To illustrate: we conclude the dimensions of two axes are the same because they are reduced together in another operation -- this should not force the axes to share a projection in the processed operation. To prevent the contamination, in each `derive_projections` call, we freshen the projection ids in the (inferred) shapes, and regenerate and re-solve the constraints with the fresh projection ids.
 
@@ -63,7 +63,7 @@ Shape inference does not maintain padding for axes of individual tensor nodes, t
 
 ### Preventing Premature Guessing with Total_elems Constraints
 
-A critical aspect of shape inference is avoiding premature "guessing" of dimension variables to minimal values (dimension-1 or no-further-axes for rows) when such guessing would make pending constraints unsatisfiable. This is particularly important for `Total_elems` constraints of the form:
+A critical aspect of shape inference is avoiding premature "guessing" of dimension variables to minimal values (dimension-1-no-label or no-further-axes for rows) when such guessing would make pending constraints unsatisfiable. This is particularly important for `Total_elems` constraints of the form:
 
 ```ocaml
 Total_elems { numerator = Strided_var { coeff; var; denom }; divided_by }
@@ -89,7 +89,7 @@ This mechanism ensures that `Total_elems` constraints with stride-based numerato
 
 ### Inference strategy
 
-The actual shape inference combines row polymorphism with (nominal) subtyping, as known in the type inference literature. The subtyping stems merely from the fact that a dimension-1 axis can be used in the context of any dimension due to per-axis broadcasting. Row polymorphism stems from broadcasting to more axes: for example, when unifying an unknown (shape) row with a known one, we cannot assume that the unknown row will have just the axes of the known one, because maybe the known row is meant to be broadcasted here to more axes. The combination of row polymorphism with nominal subtyping means that the constraints we are solving are inequalities, both inequalities between rows (the `Row.t` type, i.e. the `row` type above), and between axes/dimensions (the `Row.dim` type). We maintain the inequality ordering between variables in the environment to compute the transitive closure during simplification. We also maintain a least upper bound on the solution.
+The actual shape inference combines row polymorphism with (nominal) subtyping, as known in the type inference literature. The subtyping stems merely from the fact that a dimension-1-no-label axis can be used in the context of any dimension due to per-axis broadcasting. Row polymorphism stems from broadcasting to more axes: for example, when unifying an unknown (shape) row with a known one, we cannot assume that the unknown row will have just the axes of the known one, because maybe the known row is meant to be broadcasted here to more axes. The combination of row polymorphism with nominal subtyping means that the constraints we are solving are inequalities, both inequalities between rows (the `Row.t` type, i.e. the `row` type above), and between axes/dimensions (the `Row.dim` type). We maintain the inequality ordering between variables in the environment to compute the transitive closure during simplification. We also maintain a least upper bound on the solution.
 
 ```ocaml
 type dim_entry =
@@ -205,7 +205,7 @@ During the solution process, the constraints are incorporated, or propagated, in
 
 ## Solving the constraints
 
-The constraints are solved by: unification of the equation constraints, unification-like simplification of the inequality constraints, propagation of the complex constraints. The inequalities are like in type systems combining parametric polymorphism with structural and nominal subtyping, where the nominal subtyping relation states that dimension-1 axis is smaller than all axes, and axes of other dimensions are incomparable. For rows, the subtyping is suffix-wise (shorter is smaller) and axis-wise.
+The constraints are solved by: unification of the equation constraints, unification-like simplification of the inequality constraints, propagation of the complex constraints. The inequalities are like in type systems combining parametric polymorphism with structural and nominal subtyping, where the nominal subtyping relation states that a dimension-1 axis without a label is smaller than all axes, and axes of other mismatching dimensions or mismatching labels are incomparable. For rows, the subtyping is suffix-wise (shorter is smaller) and axis-wise.
 
 Simplification of an inequality, and constraint propagation, can generate more constraints, so we need to be careful to keep it terminating. The solution proceeds in stages. Currently there are 8 stages, with a fractional stage coming from splitting an earlier design.
````
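
The revised subtyping relation from the doc can be made concrete with a small sketch. This is not the repository's `Row.dim` type (which also carries projection ids, variables, and convolution cases); the record fields below are simplified assumptions, but the ordering matches the prose: only an unlabeled dim-1 axis sits below every other axis.

```ocaml
(* Minimal sketch of the per-axis subtyping described above; the real
   [Row.dim] also has [Var] and [Conv_input] cases, omitted here. *)
type dim = { d : int; label : string option }

(* [fits_into sub sup]: can an axis [sub] be used where [sup] is expected?
   Unlabeled dim-1 broadcasts to anything; otherwise both the size and the
   optional label must agree. *)
let fits_into sub sup =
  match (sub, sup) with
  | { d = 1; label = None }, _ -> true
  | { d = d1; label = l1 }, { d = d2; label = l2 } ->
      d1 = d2 && Option.equal String.equal l1 l2

let () =
  assert (fits_into { d = 1; label = None } { d = 7; label = Some "ch" });
  (* After this commit, a labeled dim-1 behaves like dim > 1: no broadcast. *)
  assert (not (fits_into { d = 1; label = Some "ch" } { d = 7; label = None }))
```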

tensor/row.ml

Lines changed: 11 additions & 4 deletions
```diff
@@ -1816,8 +1816,8 @@ let%track5_sexp solve_dim_ineq ~(stage : stage) origin ~(cur : dim) ~(subr : dim
        @@ Shape_error
             ("dimension comparison for axis: different labels", [ Dim_mismatch [ cur; subr ] ])
   | Dim { d = d1; _ }, Dim { d = d2; _ } when d1 = d2 -> ([], env)
-  | _, Dim { d = 1; _ } -> ([], env)
-  | (Dim { d = 1; _ } as cur), _ -> ([ Dim_eq { d1 = subr; d2 = cur; origin } ], env)
+  | _, Dim { d = 1; label = None; _ } -> ([], env)
+  | (Dim { d = 1; label = None; _ } as cur), _ -> ([ Dim_eq { d1 = subr; d2 = cur; origin } ], env)
   | Conv_input _, _ | _, Conv_input _ -> ([ Dim_eq { d1 = subr; d2 = cur; origin } ], env)
   | Var cur_v, Var subr_v -> (
       match (find_dim env.dim_env cur_v, find_dim env.dim_env subr_v) with
@@ -2488,9 +2488,16 @@ let%debug5_sexp solve_row_ineq ~(stage : stage) origin ~(cur : t) ~(subr : t) en
         List.map2_exn (take_from_end r_cur.dims lub_len) (take_from_end lub2.dims lub_len)
           ~f:(fun d1 d2 ->
             match (d1, d2) with
-            | Dim { d = 1; _ }, _ -> d1
-            | _, Dim { d = 1; _ } -> d2
+            (* Prefer dimensions without labels (more general), then prefer d=1 (more general
+               size) *)
+            | Dim { d = 1; label = None; _ }, _ -> d1
+            | _, Dim { d = 1; label = None; _ } -> d2
+            | Dim { d = 1; label = Some _; _ }, Dim { label = None; _ } -> d2
+            | Dim { label = None; _ }, Dim { d = 1; label = Some _; _ } -> d1
             | Dim { d = d1; _ }, Dim { d = d2; _ } when d1 <> d2 -> get_dim ~d:1 ~proj_id:48 ()
+            | Dim { label = Some l1; _ }, Dim { label = Some l2; _ }
+              when not (String.equal l1 l2) ->
+                get_dim ~d:1 ~proj_id:63 ()
             | Conv_input { stride; output = Dim s; _ }, Dim s'
             | Dim s', Conv_input { stride; output = Dim s; _ }
               when !use_padding && stride * s.d <> s'.d ->
```
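
Read on its own, the new least-upper-bound merge in `solve_row_ineq` implements a preference order. The sketch below restates it over the simplified `dim` record from the earlier example (field names are assumptions, not the repository's types), with `fresh_dim1` standing in for the `get_dim ~d:1 ~proj_id:...` calls that mint a fresh dimension-1 axis.

```ocaml
(* Per-axis LUB merge: keep the more general bound when possible, otherwise
   fall back to a fresh unlabeled dim-1. *)
let lub_merge ~fresh_dim1 d1 d2 =
  match (d1, d2) with
  (* Unlabeled dim-1 is the most general bound: keep it. *)
  | { d = 1; label = None }, _ -> d1
  | _, { d = 1; label = None } -> d2
  (* Prefer an unlabeled axis over a labeled dim-1. *)
  | { d = 1; label = Some _ }, { label = None; _ } -> d2
  | { label = None; _ }, { d = 1; label = Some _ } -> d1
  (* Mismatching sizes or labels: no informative bound, use a fresh dim-1. *)
  | { d = s1; _ }, { d = s2; _ } when s1 <> s2 -> fresh_dim1 ()
  | { label = Some l1; _ }, { label = Some l2; _ } when not (String.equal l1 l2) ->
      fresh_dim1 ()
  | _ -> d1
```

Note that incomparable axes collapse to the bottom element rather than raising an error: in this code path the LUB is only a hint used for defaulting, not a hard constraint.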

tensor/shape.ml

Lines changed: 36 additions & 30 deletions
```diff
@@ -328,7 +328,7 @@ let axis_map_to_dims_bio (type a) ?(default : a option) (idcs : a axis_map) =
   let back_axes, front_axes =
     Map.to_alist axes
     |> List.partition_map ~f:(fun ({ AxisKey.from_end; pos = i; _ }, v) ->
-          if from_end then Either.First (i, v) else Second (i, v))
+           if from_end then Either.First (i, v) else Second (i, v))
   in
   let back_size = List.fold back_axes ~init:0 ~f:(fun accu (i, _) -> max i accu) in
   let front_size = List.fold front_axes ~init:0 ~f:(fun accu (i, _) -> max i accu) in
@@ -888,21 +888,21 @@ let%debug4_sexp get_inequalities ({ shape = cur_sh; logic; id = _ } as _upd : up
     ( proj_axis_env,
       (Option.to_list static_range
       |> List.map ~f:(fun range ->
-            Dim_eq
-              {
-                d1 = get_dim ~d:range ();
-                d2 = slice_var;
-                origin =
-                  [
-                    {
-                      lhs_name = sh.debug_name;
-                      lhs_kind = `Batch;
-                      rhs_name = Idx.symbol_ident static_symbol;
-                      rhs_kind = `Batch;
-                      operation = Some "Slice";
-                    };
-                  ];
-              }))
+             Dim_eq
+               {
+                 d1 = get_dim ~d:range ();
+                 d2 = slice_var;
+                 origin =
+                   [
+                     {
+                       lhs_name = sh.debug_name;
+                       lhs_kind = `Batch;
+                       rhs_name = Idx.symbol_ident static_symbol;
+                       rhs_kind = `Batch;
+                       operation = Some "Slice";
+                     };
+                   ];
+               }))
       @ [
           Row_eq { r1 = expanded_batch; r2 = sh.batch; origin = get_origin `Batch };
           Row_eq { r1 = cur_sh.input; r2 = sh.input; origin = get_origin `Input };
@@ -1832,12 +1832,21 @@ let%debug4_sexp derive_projections (update_step : update_step) : Idx.projections
 let make ?batch_dims ?input_dims ?output_dims ?batch_axes ?input_axes ?output_axes
     ?(deduced = Not_constrained) ~debug_name ~id () =
   let open Row in
+  let known_no_batch =
+    match (batch_dims, batch_axes) with Some [], None -> true | None, Some [] -> true | _ -> false
+  in
+  let num_dim1_output = Option.to_list output_dims |> List.join |> List.count ~f:(fun d -> d = 1) in
+  let f kind d =
+    match kind with
+    | `Batch | `Input -> get_dim ~d ()
+    | `Output ->
+        if not known_no_batch && num_dim1_output = 1 && d = 1 then
+          let label = debug_name ^ "_output" in
+          get_dim ~d ~label ()
+        else get_dim ~d ()
+  in
   let make_dims kind ds =
-    {
-      dims = List.map ~f:(fun d -> get_dim ~d ()) ds;
-      bcast = Broadcastable;
-      prov = provenance ~sh_id:id ~kind;
-    }
+    { dims = List.map ~f:(f kind) ds; bcast = Broadcastable; prov = provenance ~sh_id:id ~kind }
   in
   let make_axes kind ds =
     {
@@ -1987,15 +1996,12 @@ let to_string_hum ?(style = Row.Axis_size) (sh : t) =
     let dims = (row_of_kind kind sh).dims in
     String.concat ~sep:","
     @@ List.mapi dims ~f:(fun i d ->
-        let num =
-          match kind with
-          | `Input -> n_batch + n_outputs + i
-          | `Output -> n_batch + i
-          | `Batch -> i
-        in
-        match style with
-        | Row.Only_labels | Axis_size | Projection_and_size -> Row.dim_to_string style d
-        | Axis_number_and_size -> Int.to_string num ^ ":" ^ Row.dim_to_string style d)
+           let num =
+             match kind with `Input -> n_batch + n_outputs + i | `Output -> n_batch + i | `Batch -> i
+           in
+           match style with
+           | Row.Only_labels | Axis_size | Projection_and_size -> Row.dim_to_string style d
+           | Axis_number_and_size -> Int.to_string num ^ ":" ^ Row.dim_to_string style d)
   in
   let batch_dims = dims_to_string `Batch in
   let input_dims = dims_to_string `Input in
```
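
The effect of the new heuristic in `make` can be illustrated with hypothetical calls (argument values below are made up; only the parameter names come from the diff above). When the batch row is left open and the output row has exactly one size-1 axis, that axis now receives a synthetic label derived from `debug_name`, so under the refined rules it no longer broadcasts away silently; with an explicitly empty batch row it stays unlabeled.

```ocaml
(* Hypothetical usage sketches for [make] as changed above. *)

(* Batch axes unspecified and a single dim-1 output axis: [known_no_batch]
   is false and [num_dim1_output] is 1, so the axis gets the label
   "x_output" and behaves like a dim > 1 axis during inference. *)
let x = make ~output_dims:[ 1 ] ~debug_name:"x" ~id:1 ()

(* Explicitly empty batch row: [known_no_batch] is true, so the dim-1
   output axis stays unlabeled and keeps its broadcasting behavior. *)
let y = make ~batch_dims:[] ~output_dims:[ 1 ] ~debug_name:"y" ~id:2 ()
```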
