Skip to content

Commit ab15bd2

Browse files
committed
Multiple fixes to shape inference around closing terminal rows and Total_elems inference
1. Be careful about stages when eliminating row constraints. 2. Move elimination of non-LUB rows from stage 2 to stage 4 (and partially 3) since new LUBs can arrive by other inference. 3. Be more careful when eliminating Total_elems to not prevent incorporation of LUB values via adding no-further-axes, but also to not impose single-axis-dim-1 accidentally (since no-further-axes is also Total_elems 1).
1 parent e22b6f9 commit ab15bd2

File tree

3 files changed

+44
-42
lines changed

3 files changed

+44
-42
lines changed

bin/primitive_ops.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ let%debug_sexp graph_t () : unit =
2222
let ctx = Backend.make_context stream in
2323
let open Operation.At in
2424
CDSL.virtualize_settings.enable_device_only <- false;
25-
let%op f x = recip x in
25+
let%op f x = sin x in
2626
let size = 50 in
2727
let xs = Array.init size ~f:Float.(fun i -> (of_int i / 10.) + 0.1) in
2828
let x_flat =

lib/row.ml

Lines changed: 40 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -822,7 +822,7 @@ let check_empty_row r =
822822
let rec apply_rows_constraint ~stage (rows : row list) (constr : row_constraint) (env : environment)
823823
: constraint_ list * environment =
824824
match rows_to_row_or_vars rows with
825-
| Either.First single_row -> apply_row_constraint ~stage single_row constr env
825+
| Either.First single_row -> apply_row_constraint stage single_row constr env
826826
| Either.Second (all_dims, row_vars) -> (
827827
match constr with
828828
| Exact dims when List.length dims < List.length all_dims ->
@@ -897,7 +897,7 @@ let rec apply_rows_constraint ~stage (rows : row list) (constr : row_constraint)
897897
)
898898
| _ -> ([ Rows_constr { r = rows; constr } ], env))
899899

900-
and apply_row_constraint ~stage (r : row) (constr : row_constraint) env : constraint_ list * _ =
900+
and apply_row_constraint stage (r : row) (constr : row_constraint) env : constraint_ list * _ =
901901
if is_unconstrained constr then ([], env)
902902
else
903903
let reduce constr ~beg_dims ~dims =
@@ -1418,7 +1418,7 @@ let%debug5_sexp rec unify_row ~stage (eq : t * t) (env : environment) :
14181418
List.iter subr ~f:(fun subr ->
14191419
ineqs := Row_ineq { subr = row_of_var subr value.id; cur = r2 } :: !ineqs);
14201420
Option.iter lub ~f:(fun lub -> ineqs := Row_ineq { cur = lub; subr = r2 } :: !ineqs);
1421-
let extras, env = apply_row_constraint ~stage value constr env in
1421+
let extras, env = apply_row_constraint stage value constr env in
14221422
ineqs := extras @ !ineqs;
14231423
env)
14241424
else env
@@ -1889,15 +1889,15 @@ let%debug5_sexp close_dim_terminal ~(stage : stage) (env : environment) (dim : d
18891889
transitively terminal. *)
18901890
[]
18911891

1892-
let last_dim_is dims d2 = match List.last dims with Some (Dim { d; _ }) -> d = d2 | _ -> false
1892+
let last_dim_is dims p = match List.last dims with Some (Dim { d; _ }) -> p d | _ -> false
18931893

18941894
let r_dims r =
18951895
match r.bcast with Broadcastable -> r.dims | Row_var { beg_dims; _ } -> beg_dims @ r.dims
18961896

1897-
let rec eliminate_rows_constraint ~lub (rows : row list) (constr : row_constraint)
1897+
let rec eliminate_rows_constraint stage ~lub (rows : row list) (constr : row_constraint)
18981898
(env : environment) : constraint_ list =
18991899
match rows_to_row_or_vars rows with
1900-
| Either.First single_row -> eliminate_row_constraint ~lub single_row constr env
1900+
| Either.First single_row -> eliminate_row_constraint stage ~lub single_row constr env
19011901
| Either.Second (_all_dims, row_vars) -> (
19021902
let rev_row_vars = List.rev row_vars in
19031903
match
@@ -1906,7 +1906,7 @@ let rec eliminate_rows_constraint ~lub (rows : row list) (constr : row_constrain
19061906
| _, { kind = `Output; _ } -> true
19071907
| _ -> false) )
19081908
with
1909-
| Total_elems _, Some (idx, (v, id)) ->
1909+
| Total_elems _, Some (idx, (v, id)) when is_stage5_up stage ->
19101910
let other_vars = List.filteri rev_row_vars ~f:(fun i _ -> i <> idx) in
19111911
let other_vars = List.map other_vars ~f:(fun (v, id) -> row_of_var v id) in
19121912
let other_eqs =
@@ -1918,16 +1918,18 @@ let rec eliminate_rows_constraint ~lub (rows : row list) (constr : row_constrain
19181918
| { bcast = Row_var { v = v'; _ }; _ } as r when equal_row_var v' v -> r
19191919
| r -> { r with dims = r_dims r; bcast = Broadcastable })
19201920
in
1921-
other_eqs @ eliminate_rows_constraint ~lub rows constr env
1921+
other_eqs @ eliminate_rows_constraint stage ~lub rows constr env
19221922
| _ -> [ Rows_constr { r = rows; constr } ])
19231923

1924-
and eliminate_row_constraint ~lub (r : row) (constr : row_constraint) env : constraint_ list =
1924+
and eliminate_row_constraint stage ~lub (r : row) (constr : row_constraint) env : constraint_ list =
1925+
let keep_constr = if is_stage5_up stage then [] else [ Rows_constr { r = [ r ]; constr } ] in
19251926
match r with
19261927
| { bcast = Broadcastable; _ } ->
19271928
(* The environment is unchanged, as apply_row_constraint would update only the constr. *)
1928-
let ineqs, _env = apply_row_constraint ~stage:Stage5 r constr env in
1929+
let ineqs, _env = apply_row_constraint stage r constr env in
19291930
List.concat_map ineqs ~f:(function
1930-
| Rows_constr { r = rows; constr } -> eliminate_rows_constraint ~lub:None rows constr env
1931+
| Rows_constr { r = rows; constr } ->
1932+
eliminate_rows_constraint stage ~lub:None rows constr env
19311933
| ineq -> [ ineq ])
19321934
| { bcast = Row_var { v; beg_dims }; dims; id } -> (
19331935
let r1 = row_of_var v id in
@@ -1936,68 +1938,68 @@ and eliminate_row_constraint ~lub (r : row) (constr : row_constraint) env : cons
19361938
match reduce_row_constraint constr ~beg_dims ~dims with
19371939
| Total_elems { numerator; divided_by } -> (
19381940
match (numerator, divided_by, lub) with
1939-
| Num_elems 1, vs, _ ->
1941+
| Num_elems 1, vs, _ when is_stage5_up stage ->
19401942
no_further_axes
19411943
:: List.map vs ~f:(fun v ->
19421944
let d2 = get_dim ~d:1 () in
19431945
Dim_eq { d1 = Var v; d2 })
1944-
| Num_elems d, [], None ->
1946+
| Num_elems d, [], None when d <> 1 && is_stage3_up stage ->
19451947
let dim = get_dim ~d () in
19461948
[ Row_eq { r1; r2 = { dims = [ dim ]; bcast = Broadcastable; id } } ]
1947-
| Num_elems d, [], Some { dims; _ } when last_dim_is dims d ->
1949+
| Num_elems d, [], Some { dims; _ } when d <> 1 && last_dim_is dims (( = ) d) ->
19481950
let dim = get_dim ~d () in
19491951
[ Row_eq { r1; r2 = { dims = [ dim ]; bcast = Broadcastable; id } } ]
19501952
| Num_elems _, [], Some lub ->
1951-
let ineqs, _env = apply_row_constraint ~stage:Stage5 lub constr env in
1953+
let ineqs, _env = apply_row_constraint stage lub constr env in
19521954
List.concat_map ineqs ~f:(function
19531955
| Rows_constr { r = rows; constr } ->
1954-
eliminate_rows_constraint ~lub:None rows constr env
1956+
eliminate_rows_constraint stage ~lub:None rows constr env
19551957
| ineq -> [ ineq ])
1956-
| Num_elems d, [ v ], _ ->
1958+
| Num_elems d, [ v ], None when is_stage4_up stage ->
19571959
no_further_axes :: [ Dim_eq { d1 = Var v; d2 = get_dim ~d () } ]
1958-
| Strided_var { coeff = _; var = _; denom = _ }, [], _ ->
1960+
| Num_elems d, [ v ], Some ({ dims; _ } as r2)
1961+
when last_dim_is dims (fun d2 -> d % d2 = 0) ->
1962+
let d2 = match List.last dims with Some (Dim { d; _ }) -> d | _ -> assert false in
1963+
let row_eq =
1964+
if d = d2 && is_stage5_up stage then no_further_axes else Row_eq { r1; r2 }
1965+
in
1966+
row_eq :: [ Dim_eq { d1 = Var v; d2 = get_dim ~d:(d / d2) () } ]
1967+
| Strided_var { coeff = _; var = _; denom = _ }, _, _ ->
19591968
(* The row variable should have coeff * var elements total *)
19601969
(* We can't determine the exact shape without knowing var *)
1961-
[]
1962-
| Strided_var { coeff; var; denom }, [ v ], _ when equal_dim_var var v ->
1963-
let coeff = Utils.safe_force coeff in
1964-
if coeff % denom <> 0 then
1965-
raise
1966-
@@ Shape_error
1967-
( [%string "Strided_var constraint: %{coeff#Int} not divisible by %{denom#Int}"],
1968-
[ Row_mismatch [ r ] ] )
1969-
else
1970-
(* coeff * var / var = coeff *)
1971-
no_further_axes :: [ Dim_eq { d1 = Var v; d2 = get_dim ~d:(coeff / denom) () } ]
1972-
| _ -> [])
1970+
(* FIXME: probably can do better here *)
1971+
keep_constr
1972+
| _ -> keep_constr)
19731973
| Exact dims -> [ Row_eq { r1; r2 = { dims; bcast = Broadcastable; id } } ]
1974-
| _ -> [])
1974+
| Unconstrained -> [])
19751975

1976-
let%debug5_sexp close_row_terminal ~(stage : stage) (env : environment)
1976+
let%track5_sexp close_row_terminal ~(stage : stage) (env : environment)
19771977
({ dims; bcast; id } as _r : row) : constraint_ list =
19781978
let suffix () = List.map dims ~f:(fun d -> Terminal_dim d) in
19791979
match bcast with
1980-
| Broadcastable -> if is_stage5_up stage then [] else suffix ()
1980+
| Broadcastable -> if is_stage6_up stage then [] else suffix ()
19811981
| Row_var { v; beg_dims } -> (
19821982
let term_dims () = List.map beg_dims ~f:(fun d -> Terminal_dim d) @ suffix () in
19831983
let r1 : row = row_of_var v id in
19841984
let no_further_axes = Row_eq { r1; r2 = { dims = []; bcast = Broadcastable; id } } in
19851985
match Map.find env.row_env v with
1986-
| Some (Bounds_row { lub = None; constr = Unconstrained; _ }) when is_stage3_up stage ->
1986+
| Some (Bounds_row { lub = None; constr = Unconstrained; _ }) when is_stage4_up stage ->
19871987
[%log6 "terminal row: closing", (_r : row)];
19881988
no_further_axes :: term_dims ()
19891989
| Some (Bounds_row { lub = None; constr; _ })
19901990
when is_stage2_up stage && not (equal_row_constraint constr Unconstrained) ->
19911991
let ineqs =
19921992
(* This is the constraint on the row variable, not on the original row. *)
1993-
try eliminate_row_constraint r1 ~lub:None constr env
1993+
try eliminate_row_constraint stage r1 ~lub:None constr env
19941994
with Shape_error (s, trace) -> raise @@ Shape_error (s, Row_mismatch [ r1 ] :: trace)
19951995
in
1996-
ineqs @ term_dims ()
1996+
(* FIXME: at which stage should we drop the terminal row? *)
1997+
let keep_terminal = if is_stage6_up stage then [] else [ Terminal_row r1 ] in
1998+
ineqs @ term_dims () @ keep_terminal
19971999
| Some (Solved_row _) -> assert false
19982000
| Some (Bounds_row { lub = Some lub; _ }) when is_stage3_up stage ->
19992001
Row_eq { r1; r2 = lub } :: term_dims ()
2000-
| _ when is_stage5_up stage -> []
2002+
| _ when is_stage6_up stage -> []
20012003
| _ ->
20022004
[%log6 "terminal row: keeping", (_r : row), "as", (r1 : row)];
20032005
Terminal_row r1 :: term_dims ())
@@ -2113,7 +2115,7 @@ let%debug4_sexp solve_inequalities ~(stage : stage) (ineqs : constraint_ list) (
21132115
| Bounds_row { lub; constr; _ } ->
21142116
(* TODO: should we store the id somewhere? *)
21152117
let id = phantom_row_id in
2116-
eliminate_row_constraint (row_of_var v id) ~lub constr env
2118+
eliminate_row_constraint stage (row_of_var v id) ~lub constr env
21172119
| _ -> []
21182120
in
21192121
let finalizing_entries : constraint_ list =

lib/shape_inference.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,9 +161,9 @@ During the solution process, the constraints are incorporated, or propagated, in
161161
The constraints are solved by: unification of the equation constraints, unification-like simplification of the inequality constraints, and propagation of the complex constraints. Simplification of an inequality, and constraint propagation, can generate more constraints, so we need to be careful to keep the solving process terminating. The solution proceeds in stages.
162162

163163
* Stage 1 runs online, as tensors are composed, and conservatively performs unification and constraint propagation. Stages 2, 3, 4 are performed only when needed: when projections or dimensions are requested.
164-
* Stage 2, when solving the constraints, substitutes dim variables in terminal shapes that do not have a LUB or other constraints, by dimension-1. (This is generalized at stage 6 to all variables.) It substitutes row variables in terminal shapes that do not have a LUB by one axis if that's required to satisfy the variable's constraint.
165-
* Stage 3, when solving the constraints, sets yet-unknown dimension and row variables in terminal shapes to their least upper bounds (if any). It substitutes row variables in terminal shapes that do not have a LUB by no-further-axes. (This is generalized at stage 5 to all variables.)
166-
* Stage 4 sets yet-unknown dimensions with >1 lower bounds from direct accesses, to their LUBs if they have any, otherwise to the lower bound.
164+
* Stage 2, when solving the constraints, substitutes dimension-1 for dim variables in terminal shapes that have neither a LUB nor other constraints. (This is generalized at stage 6 to all variables.) (FIXME: reconsider this, see the algorithm for row variables: a new LUB can still be inferred.) It also forces coefficients coming from precision byte sizes.
165+
* Stage 3, when solving the constraints, sets yet-unknown dimension and row variables in terminal shapes to their least upper bounds (if any). It replaces row variables in terminal shapes that do not have a LUB with a single axis, if that is required to satisfy the variable's constraint.
166+
* Stage 4 sets yet-unknown dimensions with >1 lower bounds from direct accesses to their LUBs if they have any, and otherwise to the lower bound. It replaces row variables in terminal shapes that do not have a LUB with no-further-axes. (This is generalized at stage 6 to all variables.)
167167
* Stage 5 addresses `Total_elems` and `Exact` constraints with yet-unknown row variables. For `Total_elems` and a single row variable: if the constraint can be satisfied by assuming the row variable is no-further-axes, it sets the row variable to `Broadcastable`, otherwise it sets it to one axis of the required dimension. For multiple row variables, if one is of the Output kind, sets the other variables to no-further-axes, and retries.
168168
* Stage 6 sets row variables in the remaining inequalities to no-further-axes values. This can unlock further between-axis inequalities because of row variables sandwiched between leftmost axes from their side of the inequality and rightmost axes from the other side of the inequality.
169169
* Stage 7 sets all dim variables and row variables remaining in updated shapes to dimension-1 and no-further-axes, respectively.

0 commit comments

Comments
 (0)