Skip to content

Commit 4e85a0f

Browse files
committed
Fix: Shape.set_terminal for parameters
1 parent 02b9ff8 commit 4e85a0f

File tree

5 files changed

+25
-3
lines changed

5 files changed

+25
-3
lines changed

docs/shape_inference.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ Simplification of an inequality, and constraint propagation, can generate more c
174174

175175
* Stage 1 is online as tensors are composed, and conservatively performs unification and constraint propagation. Stages 2, 3, 4 are only performed once necessary: when projections or dimensions are requested.
176176
* Stage 2, when solving the constraints, substitutes dim variables in terminal shapes that do not have a LUB or other constraints, by dimension-1. (This is generalized at stage 6 to all variables.) (FIXME: reconsider this, see the algo for row variables: a new LUB can still be inferred.) Forces coefficients coming from precision byte sizes.
177-
* Stage 3, when solving the constraints, sets yet-unknown dimension and row variables in terminal shapes to their least upper bounds (if any), but for rows only if they don't have a `Total_elems 1` constraint. It substitutes row variables in terminal shapes that do not have a LUB by one axis if that's required to satisfy the variable's constraint. In Total_elems constraints with multiple row variables, it substitutes row variables originating from axes of non-output kind, and which do not have a LUB, by no-further-axes -- otherwise these constraints can be too hard to unlock.
177+
* Stage 3, when solving the constraints, sets yet-unknown dimension and row variables in terminal shapes to their least upper bounds (LUBs), if any, but for rows only if they don't have a `Total_elems 1` constraint. It substitutes row variables in terminal shapes that do not have a LUB by one axis if that's required to satisfy the variable's constraint. In Total_elems constraints with multiple row variables, it substitutes row variables originating from axes of non-output kind, and which do not have a LUB, by no-further-axes -- otherwise these constraints can be too hard to unlock.
178178
* Stage 4 sets yet-unknown dimensions with >1 lower bounds from direct accesses, to their LUBs if they have any. It substitutes row variables in terminal shapes that do not have a LUB by no-further-axes. (This is generalized at stage 6 to all variables.) At this stage, we inject `Shape_row` constraints into the inequalities, so that we can re-process the variables of interest without traversing the whole environment.
179179
* Stage 5 addresses `Total_elems` and `Exact` constraints with yet-unknown row variables. For `Total_elems` and a single row variable: if the constraint can be satisfied by assuming the row variable is no-further-axes, it sets the row variable to `Broadcastable`, otherwise it sets it to one axis of the required dimension. For multiple row variables, if one is of the Output kind, sets the other variables to no-further-axes, and retries.
180180
* Stage 6 sets row variables in the remaining inequalities and updated shapes to no-further-axes values. This can unlock further between-axis inequalities because of row variables sandwiched between leftmost axes from their side of the inequality and rightmost axes from the other side of the inequality. In row constraints, this also unlocks inference for the embedded dim variables.

tensor/shape.ml

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1377,6 +1377,23 @@ let set_equal delayed_ref1 delayed_ref2 =
13771377
}
13781378
:: !active_constraints
13791379

1380+
(* Marks the shape [sh] as terminal: prepends one [Row.Terminal_row] constraint for each of its
   rows (batch, input, output, in that order) to [active_constraints]. Returns unit; the only
   effect is the update of the [active_constraints] reference. *)
let set_terminal sh =
1381+
(* Builds the origin record attached to a constraint for the row of the given [kind]. The same
   [kind] is used on both sides; the right-hand side name is the fixed string "(parameter)". *)
let get_origin kind =
1382+
Row.
1383+
{
1384+
lhs_name = sh.debug_name;
1385+
lhs_kind = kind;
1386+
rhs_name = "(parameter)";
1387+
rhs_kind = kind;
1388+
operation = Some "set_terminal";
1389+
}
1390+
in
1391+
(* The new constraints go in front of the already-registered ones; each carries a
   single-element origin list. *)
active_constraints :=
1392+
Row.Terminal_row (sh.batch, [ get_origin `Batch ])
1393+
:: Row.Terminal_row (sh.input, [ get_origin `Input ])
1394+
:: Row.Terminal_row (sh.output, [ get_origin `Output ])
1395+
:: !active_constraints
1396+
13801397
let unsafe_reinitialize () =
13811398
update_uid := 0;
13821399
state := Row.empty_env;
@@ -1473,7 +1490,6 @@ let%debug4_sexp propagate_shapes (update_step : update_step) : unit =
14731490
iter_shapes update_step ~f:(apply_env_t !state);
14741491
let _, ineqs = get_inequalities update_step in
14751492
active_update_steps := update_step :: !active_update_steps;
1476-
let _debug_new_active_update_steps : update_step list = !active_update_steps in
14771493
active_constraints := ineqs @ !active_constraints;
14781494
let ineqs', env = Row.solve_inequalities ~stage:Row.Stage1 ineqs !state in
14791495
let _debug_remaining_constraints : Row.constraint_ list = ineqs' in

tensor/shape.mli

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,11 @@ val unsafe_reinitialize : unit -> unit
176176

177177
(** {2 Internal-ish API.} *)
178178

179+
val set_terminal : t -> unit
180+
(** Mark the shape as terminal, so that its rows can be closed to Least Upper Bounds (LUBs). This
181+
function is only intended for parameter shapes, which would otherwise not be terminal because
182+
of the initialization expressions of the parameters. *)
183+
179184
(** How to propagate shape updates and do the last update of [Tensor.t.shape] when finalizing the
180185
tensor. Axes are broadcast-expanded on a bottom-up update to fit the incoming shape. *)
181186
type logic =

tensor/tensor.ml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,7 @@ let%debug7_sexp param ~t (name : string) ?(more_label = []) ?input_dims ?output_
589589
update computations. *)
590590
let g = (Option.value_exn ~here:[%here] t.diff).grad in
591591
Tn.update_memory_mode g Never_virtual 26;
592+
Shape.set_terminal t.shape;
592593
remove_fwd_root t;
593594
{ t with params = Set.singleton (module T) t }
594595

test/einsum/test_einsum_capture.expected

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ Test 1 - Matrix multiply with constraint-driven shapes:
7373
Captured dimensions: i=3, j=4, k=5
7474

7575
Test 2 - Chain operations with constraint propagation:
76-
base inferred shape: 1,1
76+
base inferred shape: 6,8
7777
transposed inferred shape: 8,6
7878
multiplied inferred shape: 8,10
7979
final inferred shape: 6,10

0 commit comments

Comments
 (0)