Skip to content

Commit ed5eac4

Browse files
committed
Fix the shape inference specification: we need to incorporate LUBs even for non-terminal shapes.
1 parent d822db0 commit ed5eac4

File tree

5 files changed

+39
-21
lines changed

5 files changed

+39
-21
lines changed

docs/shape_inference.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ Simplification of an inequality, and constraint propagation, can generate more c
214214
* Stage 3, when solving the constraints, sets yet-unknown dimension and row variables in terminal shapes to their least upper bounds LUB (if any), but for rows only if they don't have a `Total_elems 1` constraint. It substitutes dimension variables in terminal shapes that do not have a LUB by dim 1, **except** when the variable has `has_uniq_constr_unless` set (indicating it's in the numerator of a `Total_elems` constraint) and none of the denominator variables are also prevented from guessing -- this prevents premature guessing that would make `Total_elems` constraints unsatisfiable. It substitutes row variables in terminal shapes that do not have a LUB by one axis if that's required to satisfy the variable's constraint. In Total_elems constraints with multiple row variables, it substitutes row variables originating from axes of non-output kind, and which do not have a LUB, by no-further-axes -- otherwise these constraints can be too hard to unlock.
215215
* Stage 4 sets yet-unknown dimensions with >1 lower bounds from direct accesses, or terminal ones, to their LUBs if they have any. It substitutes row variables in terminal shapes that do not have a LUB by no-further-axes. (This is generalized at stage 6 to all variables.) At this stage, we inject `Shape_row` constraints into the inequalities, so that we can re-process the variables of interest without traversing the whole environment.
216216
* Stage 5 addresses `Total_elems` and `Exact` constraints with yet-unknown row variables. For `Total_elems` and a single row variable: if the constraint can be satisfied by assuming the row variable is no-further-axes, it sets the row variable to `Broadcastable`, otherwise it sets it to one axis of the required dimension. For multiple row variables, if one is of the Output kind, sets the other variables to no-further-axes, and retries.
217-
* Stage 6 sets row variables in the remaining inequalities and updated shapes to no-further-axes values. This can unlock further between-axis inequalities because of row variables sandwiched between leftmost axes from their side of the inequality and rightmost axes from the other side of the inequality. In row constraints, this also unlocks inference for the embedded dim variables.
217+
* Stage 6 sets row variables in the remaining inequalities and updated shapes to no-further-axes values; it also extends LUB processing to non-terminal shapes. This can unlock further between-axis inequalities because of row variables sandwiched between leftmost axes from their side of the inequality and rightmost axes from the other side of the inequality. In row constraints, this also unlocks inference for the embedded dim variables.
218218
* Stage 7 sets all dim variables remaining in updated shapes to the lower bound if they have any, otherwise to dimension-1.
219219

220220
Let's explain the shape inference functions.

tensor/row.ml

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2795,19 +2795,24 @@ let%track5_sexp close_row_terminal ~(stage : stage) ~is_param origin env
27952795
[%log6 "terminal row: keeping", (_r : row), "as", (r1 : row)];
27962796
Terminal_row (is_param, r1, origin) :: term_dims ())
27972797

2798-
let%debug5_sexp eliminate_dim_entry ~final origin v ~lub constr =
2798+
let%debug5_sexp eliminate_dim_entry stage origin v ~lub constr =
27992799
match (lub, constr) with
28002800
| Some (Dim { d; _ } as lub), At_least_dim d2 when d2 > d ->
28012801
raise
28022802
@@ Shape_error
28032803
( [%string "dereferenced at dimension %{d2#Int}, higher than use site"],
28042804
[ Dim_mismatch [ lub; Var v ] ] )
28052805
| Some _, At_least_dim 0 ->
2806-
if final then Some (Dim_eq { d1 = Var v; d2 = get_dim ~d:1 ~proj_id:57 (); origin }) else None
2807-
| Some lub, At_least_dim _ -> Some (Dim_eq { d1 = Var v; d2 = lub; origin })
2808-
| None, At_least_dim d when final ->
2806+
(* Direct access at 0 is a strong heuristic for dimension 1 axis (e.g. result of a
2807+
reduction). *)
2808+
if is_stage7 stage then Some (Dim_eq { d1 = Var v; d2 = get_dim ~d:1 ~proj_id:57 (); origin })
2809+
else None
2810+
| Some lub, (At_least_dim _ | Unconstrained_dim) when is_stage6_up stage ->
2811+
Some (Dim_eq { d1 = Var v; d2 = lub; origin })
2812+
| None, At_least_dim d when is_stage7 stage ->
28092813
Some (Dim_eq { d1 = Var v; d2 = get_dim ~d ~proj_id:58 (); origin })
2810-
| _ when final -> Some (Dim_eq { d1 = Var v; d2 = get_dim ~d:1 ~proj_id:59 (); origin })
2814+
| None, _ when is_stage7 stage ->
2815+
Some (Dim_eq { d1 = Var v; d2 = get_dim ~d:1 ~proj_id:59 (); origin })
28112816
| _ -> None
28122817

28132818
let%track5_sexp process_shape_row ~(stage : stage) origin env ({ dims; bcast; prov } as r : row) :
@@ -2825,7 +2830,7 @@ let%track5_sexp process_shape_row ~(stage : stage) origin env ({ dims; bcast; pr
28252830
("You forgot to specify the hidden dimension(s) 5", [ Row_mismatch [ r ] ])
28262831
| Some (Bounds_dim { lub; constr; has_uniq_constr_unless; _ })
28272832
when is_stage4_up stage && can_guess_dim_to_one env has_uniq_constr_unless ->
2828-
Option.to_list @@ eliminate_dim_entry ~final origin v ~lub constr
2833+
Option.to_list @@ eliminate_dim_entry stage origin v ~lub constr
28292834
| Some (Solved_dim _) -> assert false
28302835
| Some (Bounds_dim { has_uniq_constr_unless; _ })
28312836
when final && can_guess_dim_to_one env has_uniq_constr_unless ->
@@ -2849,6 +2854,8 @@ let%track5_sexp process_shape_row ~(stage : stage) origin env ({ dims; bcast; pr
28492854
let dim_eqs = process_dims beg_dims @ process_dims dims in
28502855
let r1 : row = row_of_var v prov in
28512856
match find_row env.row_env v with
2857+
| Some (Bounds_row { lub = Some lub; constr = Unconstrained; _ }) when is_stage6_up stage ->
2858+
(Row_eq { r1; r2 = lub; origin } :: dim_eqs, env)
28522859
| Some (Bounds_row { constr = Unconstrained; _ }) when not final ->
28532860
(Shape_row (r, origin) :: dim_eqs, env)
28542861
| Some (Bounds_row { constr = Unconstrained; _ }) when final ->

test/operations/attention_test.expected

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@ Testing basic multi-head attention
44
Output shape:
55
((batch
66
((dims
7-
((Dim ((d 2) (label ()) (proj_id ((Proj_id 204)))))
8-
(Dim ((d 8) (label ()) (proj_id ((Proj_id 205)))))))
7+
((Dim ((d 2) (label ()) (proj_id ((Proj_id 210)))))
8+
(Dim ((d 8) (label ()) (proj_id ((Proj_id 211)))))))
99
(bcast Broadcastable) (prov (((sh_id 64) (kind Batch))))))
1010
(input ((dims ()) (bcast Broadcastable) (prov (((sh_id 64) (kind Input))))))
1111
(output
12-
((dims ((Dim ((d 100) (label ()) (proj_id ((Proj_id 206)))))))
12+
((dims ((Dim ((d 100) (label ()) (proj_id ((Proj_id 212)))))))
1313
(bcast Broadcastable) (prov (((sh_id 64) (kind Output))))))
1414
(batch_padding ()) (input_padding ()) (output_padding ()) (id 64)
1515
(debug_name output))

test/operations/layer_norm_test.expected

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@ Testing basic mini decoder model
44
Output shape:
55
((batch
66
((dims
7-
((Dim ((d 2) (label ()) (proj_id ((Proj_id 418)))))
8-
(Dim ((d 8) (label ()) (proj_id ((Proj_id 419)))))))
7+
((Dim ((d 2) (label ()) (proj_id ((Proj_id 424)))))
8+
(Dim ((d 8) (label ()) (proj_id ((Proj_id 425)))))))
99
(bcast Broadcastable) (prov (((sh_id 154) (kind Batch))))))
1010
(input
1111
((dims ()) (bcast Broadcastable) (prov (((sh_id 154) (kind Input))))))
1212
(output
13-
((dims ((Dim ((d 100) (label ()) (proj_id ((Proj_id 420)))))))
13+
((dims ((Dim ((d 100) (label ()) (proj_id ((Proj_id 426)))))))
1414
(bcast Broadcastable) (prov (((sh_id 154) (kind Output))))))
1515
(batch_padding ()) (input_padding ()) (output_padding ()) (id 154)
1616
(debug_name layer_norm))
Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,26 @@
11
Retrieving commandline, environment, or config file variable ocannl_log_level
22
Found 0, in the config file
3-
Testing basic transformer model
4-
Output shape:
3+
Testing transformer with teacher forcing
4+
Loss shape:
5+
((batch
6+
((dims ()) (bcast Broadcastable) (prov (((sh_id 463) (kind Batch))))))
7+
(input
8+
((dims ()) (bcast Broadcastable) (prov (((sh_id 463) (kind Input))))))
9+
(output
10+
((dims ((Dim ((d 1) (label ()) (proj_id ((Proj_id 1262)))))))
11+
(bcast Broadcastable) (prov (((sh_id 463) (kind Output))))))
12+
(batch_padding ()) (input_padding ()) (output_padding ()) (id 463)
13+
(debug_name loss))
14+
Logits shape:
515
((batch
616
((dims
7-
((Dim ((d 2) (label ()) (proj_id ((Proj_id 2194)))))
8-
(Dim ((d 8) (label ()) (proj_id ((Proj_id 2195)))))))
9-
(bcast Broadcastable) (prov (((sh_id 830) (kind Batch))))))
17+
((Dim ((d 2) (label ()) (proj_id ((Proj_id 1212)))))
18+
(Dim ((d 7) (label ()) (proj_id ((Proj_id 1213)))))))
19+
(bcast Broadcastable) (prov (((sh_id 437) (kind Batch))))))
1020
(input
11-
((dims ()) (bcast Broadcastable) (prov (((sh_id 830) (kind Input))))))
21+
((dims ()) (bcast Broadcastable) (prov (((sh_id 437) (kind Input))))))
1222
(output
13-
((dims ()) (bcast Broadcastable) (prov (((sh_id 830) (kind Output))))))
14-
(batch_padding ()) (input_padding ()) (output_padding ()) (id 830)
23+
((dims ((Dim ((d 100) (label ()) (proj_id ((Proj_id 1214)))))))
24+
(bcast Broadcastable) (prov (((sh_id 437) (kind Output))))))
25+
(batch_padding ()) (input_padding ()) (output_padding ()) (id 437)
1526
(debug_name transformer))

0 commit comments

Comments
 (0)