Skip to content

Commit 324a705

Browse files
committed
Fixes #334 and Optimize shape inference #215: No more traversing of the whole env with eliminate_variables
1 parent f91ec3c commit 324a705

File tree

4 files changed

+77
-84
lines changed

4 files changed

+77
-84
lines changed

lib/row.ml

Lines changed: 61 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ type constraint_ =
227227
| Rows_constr of { r : t list; constr : row_constraint }
228228
| Terminal_dim of dim
229229
| Terminal_row of t
230+
| Shape_row of t
230231
[@@deriving compare, equal, sexp_of, variants]
231232

232233
type stage = Stage1 | Stage2 | Stage3 | Stage4 | Stage5 | Stage6 | Stage7
@@ -1390,8 +1391,8 @@ let%debug5_sexp rec unify_row ~stage (eq : t * t) (env : environment) :
13901391
| Row_eq { r1; r2 } ->
13911392
let more_ineqs, env = unify_row ~stage (r1, r2) env in
13921393
(more_ineqs @ ineqs, env)
1393-
| (Dim_ineq _ | Row_ineq _ | Dim_constr _ | Rows_constr _ | Terminal_dim _ | Terminal_row _) as
1394-
ineq ->
1394+
| ( Dim_ineq _ | Row_ineq _ | Dim_constr _ | Rows_constr _ | Terminal_dim _ | Terminal_row _
1395+
| Shape_row _ ) as ineq ->
13951396
(ineq :: ineqs, env)
13961397
in
13971398
let unify_suffix init dims1 dims2 len =
@@ -1993,7 +1994,7 @@ let%track5_sexp rec eliminate_rows_constraint ~depth stage ~lub (rows : row list
19931994
else
19941995
match rows_to_row_or_vars rows with
19951996
| Either.First single_row ->
1996-
eliminate_row_constraint ~depth:(depth + 1) stage ~lub single_row constr env
1997+
eliminate_row_constraint ~depth:(depth + 1) stage ~terminal:false ~lub single_row constr env
19971998
| Either.Second (_all_dims, row_vars) -> (
19981999
let rev_row_vars = List.rev row_vars in
19992000
match
@@ -2020,7 +2021,7 @@ let%track5_sexp rec eliminate_rows_constraint ~depth stage ~lub (rows : row list
20202021
(other_eqs @ ineqs, env)
20212022
| _ -> ([ Rows_constr { r = rows; constr } ], env))
20222023

2023-
and eliminate_row_constraint ~depth stage ~lub (r : row) (constr : row_constraint) env :
2024+
and eliminate_row_constraint ~depth stage ~terminal ~lub (r : row) (constr : row_constraint) env :
20242025
constraint_ list * environment =
20252026
let keep_constr () =
20262027
let ineqs, env = apply_row_constraint ~depth stage r constr env in
@@ -2055,7 +2056,11 @@ and eliminate_row_constraint ~depth stage ~lub (r : row) (constr : row_constrain
20552056
let dim = get_dim ~d () in
20562057
([ Row_eq { r1; r2 = { dims = [ dim ]; bcast = Broadcastable; id } } ], env)
20572058
| Num_elems _, [], Some lub ->
2058-
let ineqs, env = apply_row_constraint ~depth:(depth + 1) stage lub constr env in
2059+
let ineqs, env =
2060+
apply_row_constraint ~depth:(depth + 1) stage
2061+
(if terminal then lub else r)
2062+
constr env
2063+
in
20592064
List.fold ineqs ~init:([], env) ~f:(fun (ineqs, env) ineq ->
20602065
match ineq with
20612066
| Rows_constr { r = rows; constr } ->
@@ -2101,6 +2106,7 @@ and eliminate_row_constraint ~depth stage ~lub (r : row) (constr : row_constrain
21012106
let%track5_sexp close_row_terminal ~(stage : stage) (env : environment)
21022107
({ dims; bcast; id } as _r : row) : constraint_ list =
21032108
let suffix () = List.map dims ~f:(fun d -> Terminal_dim d) in
2109+
(* TODO: can this be simplified? Should we return the environment? *)
21042110
match bcast with
21052111
| Broadcastable -> if is_stage6_up stage then [] else suffix ()
21062112
| Row_var { v; beg_dims } -> (
@@ -2115,7 +2121,7 @@ let%track5_sexp close_row_terminal ~(stage : stage) (env : environment)
21152121
when is_stage2_up stage && not (equal_row_constraint constr Unconstrained) ->
21162122
let ineqs, _env =
21172123
(* This is the constraint on the row variable, not on the original row. *)
2118-
try eliminate_row_constraint ~depth:0 stage r1 ~lub:None constr env
2124+
try eliminate_row_constraint ~depth:0 stage r1 ~terminal:true ~lub:None constr env
21192125
with Shape_error (s, trace) -> raise @@ Shape_error (s, Row_mismatch [ r1 ] :: trace)
21202126
in
21212127
(* FIXME: at which stage should we drop the terminal row? *)
@@ -2134,48 +2140,62 @@ let%track5_sexp close_row_terminal ~(stage : stage) (env : environment)
21342140

21352141
let%debug5_sexp eliminate_dim_entry ~final v ~lub constr =
21362142
match (lub, constr) with
2137-
| _, Unconstrained_dim | _, At_least_dim 1 -> None
21382143
| Some (Dim { d; _ } as lub), At_least_dim d2 when d2 > d ->
21392144
raise
21402145
@@ Shape_error
21412146
( [%string "dereferenced at dimension %{d2#Int}, higher than use site"],
21422147
[ Dim_mismatch [ lub; Var v ] ] )
2148+
| Some _, At_least_dim 0 ->
2149+
if final then Some (Dim_eq { d1 = Var v; d2 = get_dim ~d:1 () }) else None
21432150
| Some lub, At_least_dim _ -> Some (Dim_eq { d1 = Var v; d2 = lub })
2144-
| None, At_least_dim d -> if final then Some (Dim_eq { d1 = Var v; d2 = get_dim ~d () }) else None
2151+
| None, At_least_dim d when final -> Some (Dim_eq { d1 = Var v; d2 = get_dim ~d () })
2152+
| _ when final -> Some (Dim_eq { d1 = Var v; d2 = get_dim ~d:1 () })
2153+
| _ -> None
21452154

2146-
let%debug5_sexp eliminate_variables (env : environment) ({ dims; bcast; id } as _r : row) :
2147-
constraint_ list =
2148-
let f = function
2149-
| Var v as d1 ->
2150-
Some
2151-
(match Map.find env.dim_env v with
2152-
| Some (Bounds_dim { lub; constr; _ }) ->
2153-
Option.value_or_thunk (eliminate_dim_entry ~final:true v ~lub constr)
2154-
~default:(fun () -> Dim_eq { d1; d2 = get_dim ~d:1 () })
2155-
| Some (Solved_dim _) -> assert false
2156-
| None -> Dim_eq { d1; d2 = get_dim ~d:1 () })
2157-
| _ -> None
2155+
let%track5_sexp process_shape_row ~(stage : stage) (env : environment)
2156+
({ dims; bcast; id } as r : row) : constraint_ list * environment =
2157+
let final = is_stage7 stage in
2158+
let rec finalize_upper_lower_bound = function
2159+
| Dim _ -> []
2160+
| Conv_input { output; kernel; _ } ->
2161+
finalize_upper_lower_bound output @ finalize_upper_lower_bound kernel
2162+
| Var v -> (
2163+
match Map.find env.dim_env v with
2164+
| Some (Bounds_dim { lub; constr; _ }) when is_stage4_up stage ->
2165+
Option.to_list @@ eliminate_dim_entry ~final v ~lub constr
2166+
| Some (Solved_dim _) -> assert false
2167+
| _ when final -> [ Dim_eq { d1 = Var v; d2 = get_dim ~d:1 () } ]
2168+
| _ -> [])
2169+
in
2170+
let rec has_dim_var = function
2171+
| Dim _ -> false
2172+
| Conv_input { output; kernel; _ } -> has_dim_var output || has_dim_var kernel
2173+
| Var _ -> true
21582174
in
2159-
let suffix = List.filter_map dims ~f in
2175+
let process_dims dims = List.concat_map dims ~f:finalize_upper_lower_bound in
21602176
match bcast with
2161-
| Broadcastable -> suffix
2177+
| Broadcastable ->
2178+
let keep = if (not final) && List.exists dims ~f:has_dim_var then [ Shape_row r ] else [] in
2179+
(keep @ process_dims dims, env)
21622180
| Row_var { v; beg_dims } -> (
2163-
let elim_dims = List.filter_map beg_dims ~f @ suffix in
2164-
let r2 = { dims = []; bcast = Broadcastable; id } in
2165-
let elim_var = Row_eq { r1 = row_of_var v id; r2 } in
2181+
let dim_eqs = process_dims beg_dims @ process_dims dims in
2182+
let r1 : row = row_of_var v id in
21662183
match Map.find env.row_env v with
2167-
| Some (Bounds_row { constr = Total_elems { numerator = Num_elems 1; _ }; _ }) ->
2168-
elim_var :: elim_dims
2169-
| Some (Bounds_row { constr = Total_elems _ as constr; lub; _ }) ->
2170-
let stage = Stage7 in
2171-
let ineqs, _env =
2172-
eliminate_row_constraint ~depth:0 stage ~lub (row_of_var v id)
2173-
(subst_row_constraint stage env constr)
2174-
env
2184+
| Some (Bounds_row { constr = Unconstrained; _ }) when not final ->
2185+
(Shape_row r :: dim_eqs, env)
2186+
| Some (Bounds_row { constr = Unconstrained; _ }) when final ->
2187+
(Row_eq { r1; r2 = { dims = []; bcast = Broadcastable; id } } :: dim_eqs, env)
2188+
| Some (Bounds_row { lub; constr; _ }) ->
2189+
let ineqs, env =
2190+
try eliminate_row_constraint ~depth:0 stage r1 ~terminal:false ~lub constr env
2191+
with Shape_error (s, trace) -> raise @@ Shape_error (s, Row_mismatch [ r1 ] :: trace)
21752192
in
2176-
reapply_rows_constr := false;
2177-
ineqs @ elim_dims
2178-
| _ -> elim_var :: elim_dims)
2193+
let keep = if not final then [ Shape_row r ] else [] in
2194+
(keep @ ineqs @ dim_eqs, env)
2195+
| Some (Solved_row _) -> assert false
2196+
| _ when final ->
2197+
(Row_eq { r1; r2 = { dims = []; bcast = Broadcastable; id } } :: dim_eqs, env)
2198+
| _ -> (Shape_row r :: dim_eqs, env))
21792199

21802200
let empty_env = { dim_env = Map.empty (module Dim_var); row_env = Map.empty (module Row_var) }
21812201

@@ -2233,6 +2253,9 @@ let%debug4_sexp solve_inequalities ~(stage : stage) (ineqs : constraint_ list) (
22332253
| Terminal_row r ->
22342254
let more_ineqs = close_row_terminal ~stage env @@ subst_row env r in
22352255
(more_ineqs @ ineqs, env)
2256+
| Shape_row r ->
2257+
let more_ineqs, env = process_shape_row ~stage env @@ subst_row env r in
2258+
(more_ineqs @ ineqs, env)
22362259
in
22372260
let ineqs', env = List.fold ineqs ~init:([], env) ~f in
22382261
let ineqs' = List.rev ineqs' in
@@ -2242,33 +2265,7 @@ let%debug4_sexp solve_inequalities ~(stage : stage) (ineqs : constraint_ list) (
22422265
then (ineqs', env)
22432266
else solve ineqs' env
22442267
in
2245-
match stage with
2246-
| Stage1 | Stage2 | Stage3 | Stage6 | Stage7 -> solve ineqs env
2247-
| Stage4 ->
2248-
let finalize_upper_lower_bound (v : dim_var) = function
2249-
| Bounds_dim { lub; constr; _ } ->
2250-
Option.to_list @@ eliminate_dim_entry ~final:false v ~lub constr
2251-
| _ -> []
2252-
in
2253-
let finalizing_entries : constraint_ list =
2254-
Map.fold env.dim_env ~init:[] ~f:(fun ~key ~data accu ->
2255-
finalize_upper_lower_bound key data @ accu)
2256-
in
2257-
solve (finalizing_entries @ ineqs) env
2258-
| Stage5 ->
2259-
let finalize_total_elems env v = function
2260-
| Bounds_row { lub; constr; _ } ->
2261-
(* TODO: should we store the id somewhere? *)
2262-
let id = phantom_row_id in
2263-
eliminate_row_constraint ~depth:0 stage (row_of_var v id) ~lub constr env
2264-
| _ -> ([], env)
2265-
in
2266-
let finalizing_entries, env =
2267-
Map.fold env.row_env ~init:([], env) ~f:(fun ~key ~data (accu, env) ->
2268-
let ineqs, env = finalize_total_elems env key data in
2269-
(ineqs @ accu, env))
2270-
in
2271-
solve (finalizing_entries @ ineqs) env
2268+
solve ineqs env
22722269

22732270
let rec row_to_labels env =
22742271
let rec f = function
@@ -2434,6 +2431,7 @@ let%debug4_sexp get_proj_equations (inequalities : constraint_ list) proj_axis_e
24342431
| eq -> [ eq ])
24352432
| Terminal_dim d -> [ Iterated (to_proj d) ]
24362433
| Terminal_row { dims; _ } -> List.map ~f:(fun d -> Iterated (to_proj d)) dims
2434+
| Shape_row _ -> []
24372435
| Rows_constr
24382436
{ r; constr = Total_elems { numerator = Strided_var { coeff; var; denom }; divided_by } }
24392437
-> (

lib/row.mli

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,8 @@ type constraint_ =
125125
affect the constraint (i.e. there is no "subtyping", it resembles Row_eq). *)
126126
| Terminal_dim of dim
127127
| Terminal_row of t
128+
(** A row of the shape of a terminal tensor (i.e. a tensor that does not have sub-tensors). *)
129+
| Shape_row of t (** A row of a shape of interest. *)
128130
[@@deriving compare, equal, sexp_of, variants]
129131

130132
type stage = Stage1 | Stage2 | Stage3 | Stage4 | Stage5 | Stage6 | Stage7
@@ -133,7 +135,6 @@ type stage = Stage1 | Stage2 | Stage3 | Stage4 | Stage5 | Stage6 | Stage7
133135
val subst_row : environment -> t -> t
134136
val unify_row : stage:stage -> t * t -> environment -> constraint_ list * environment
135137
val empty_env : environment
136-
val eliminate_variables : environment -> t -> constraint_ list
137138

138139
val solve_inequalities :
139140
stage:stage -> constraint_ list -> environment -> constraint_ list * environment

lib/shape.ml

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -686,43 +686,37 @@ let apply_env_t env sh =
686686
sh.input <- Row.subst_row env sh.input;
687687
sh.output <- Row.subst_row env sh.output
688688

689-
let apply_env_update ~eliminate_variables env update_step =
690-
iter_shapes update_step ~f:(apply_env_t env);
691-
if eliminate_variables then
692-
List.concat_map ~f:(Row.eliminate_variables env) @@ all_rows update_step
693-
else []
694-
695689
let%debug4_sexp propagate_shapes (update_step : update_step) : unit =
696690
(* Allow the derivation of constraints to depend on the shapes (currently, only Batch_slice
697691
does). *)
698-
assert (List.is_empty (apply_env_update ~eliminate_variables:false !state update_step));
692+
iter_shapes update_step ~f:(apply_env_t !state);
699693
let _, ineqs = get_inequalities update_step in
700694
active_update_steps := update_step :: !active_update_steps;
695+
let _debug_new_active_update_steps : update_step list = !active_update_steps in
701696
active_constraints := ineqs @ !active_constraints;
702697
let ineqs', env = Row.solve_inequalities ~stage:Row.Stage1 ineqs !state in
703698
let _debug_remaining_constraints : Row.constraint_ list = ineqs' in
704-
assert (List.is_empty (apply_env_update ~eliminate_variables:false env update_step));
699+
iter_shapes update_step ~f:(apply_env_t env);
705700
state := env
706701

707702
let%debug4_sexp finish_inference (() : unit) : unit =
708703
(* TODO: optimize to keep all needed information in unsolved, rather than starting with all
709704
constraints. *)
710705
let unsolved, env = Row.solve_inequalities ~stage:Stage2 !active_constraints !state in
711706
let unsolved, env = Row.solve_inequalities ~stage:Stage3 unsolved env in
707+
let unsolved =
708+
List.concat_map
709+
~f:(fun update_step -> List.map ~f:(fun r -> Row.Shape_row r) @@ all_rows update_step)
710+
!active_update_steps
711+
@ unsolved
712+
in
712713
let unsolved, env = Row.solve_inequalities ~stage:Stage4 unsolved env in
713714
let unsolved, env = Row.solve_inequalities ~stage:Stage5 unsolved env in
714715
let unsolved, env = Row.solve_inequalities ~stage:Stage6 unsolved env in
715716
let unsolved, env = Row.solve_inequalities ~stage:Stage7 unsolved env in
716-
let _active_update_steps : update_step list = !active_update_steps in
717-
let eliminated =
718-
List.concat_map ~f:(apply_env_update ~eliminate_variables:true env) !active_update_steps
719-
in
720-
let unsolved, env = Row.solve_inequalities ~stage:Stage7 (eliminated @ unsolved) env in
721717
assert (List.is_empty unsolved);
722-
List.iter
723-
~f:(fun update_step ->
724-
assert (List.is_empty (apply_env_update ~eliminate_variables:false env update_step)))
725-
!active_update_steps;
718+
let _active_update_steps : update_step list = !active_update_steps in
719+
List.iter ~f:(iter_shapes ~f:(apply_env_t env)) !active_update_steps;
726720
let _applied_update_steps : update_step list = !active_update_steps in
727721
active_constraints := [];
728722
active_update_steps := [];

lib/shape_inference.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,10 +163,10 @@ The constraints are solved by: unification of the equation constraints, unificat
163163
* Stage 1 is online as tensors are composed, and conservatively performs unification and constraint propagation. The later stages (2 through 7) are only performed once necessary: when projections or dimensions are requested.
164164
* Stage 2, when solving the constraints, substitutes dim variables in terminal shapes that do not have a LUB or other constraints, by dimension-1. (This is generalized at stage 6 to all variables.) (FIXME: reconsider this, see the algo for row variables: a new LUB can still be inferred.) Forces coefficients coming from precision byte sizes.
165165
* Stage 3, when solving the constraints, sets yet-unknown dimension and row variables in terminal shapes to their least upper bounds (if any), but for rows only if they don't have a `Total_elems 1` constraint. It substitutes row variables in terminal shapes that do not have a LUB by one axis if that's required to satisfy the variable's constraint.
166-
* Stage 4 sets yet-unknown dimensions with >1 lower bounds from direct accesses, to their LUBs if they have any. It substitutes row variables in terminal shapes that do not have a LUB by no-further-axes. (This is generalized at stage 6 to all variables.)
166+
* Stage 4 sets yet-unknown dimensions with >1 lower bounds from direct accesses, to their LUBs if they have any. It substitutes row variables in terminal shapes that do not have a LUB by no-further-axes. (This is generalized at stage 6 to all variables.) Just before this stage is solved, a `Shape_row` constraint is injected into the inequalities for every row of every active update step, so that this and the later stages can re-process the variables of interest without traversing the whole environment.
167167
* Stage 5 addresses `Total_elems` and `Exact` constraints with yet-unknown row variables. For `Total_elems` and a single row variable: if the constraint can be satisfied by assuming the row variable is no-further-axes, it sets the row variable to `Broadcastable`, otherwise it sets it to one axis of the required dimension. For multiple row variables, if one is of the Output kind, sets the other variables to no-further-axes, and retries.
168-
* Stage 6 sets row variables in the remaining inequalities to no-further-axes values. This can unlock further between-axis inequalities because of row variables sandwiched between leftmost axes from their side of the inequality and rightmost axes from the other side of the inequality. In row constraints, this also unlocks inference for the embedded dim variables.
169-
* Stage 7 sets all dim variables remaining in updated shapes to the lower bound if they have any, otherwise to dimension-1. It sets all row variables remaining in updated shapes to no-further-axes.
168+
* Stage 6 sets row variables in the remaining inequalities and updated shapes to no-further-axes values. This can unlock further between-axis inequalities because of row variables sandwiched between leftmost axes from their side of the inequality and rightmost axes from the other side of the inequality. In row constraints, this also unlocks inference for the embedded dim variables.
169+
* Stage 7 sets all dim variables remaining in updated shapes to their lower bound if they have one, otherwise to dimension-1. Row variables still remaining in updated shapes at this point are set to no-further-axes.
170170

171171
Let's explain the shape inference functions.
172172

0 commit comments

Comments
 (0)