Skip to content

Commit 965ac81

Browse files
lukstaficlaude
andcommitted
Fix projection iterator assignment for Conv_input dimensions
The bug caused "Multiple constraints on the same Conv_input projection" errors when using convolutions with use_padding=true. Root cause: When processing `Iterated (Var v)` equations, fresh iterators were immediately assigned to variables not yet in v_env. This happened before other equations could establish that the variable should get its index from a Conv_input affine expression instead. Fix: 1. Defer `Iterated (Var v)` processing: collect such variables and process them after all equations are handled, when their projections are known 2. Track Conv_input target projections and exclude them from early iterator creation in product_dim processing 3. Create fresh iterators for remaining product dimensions only after p_conv_input processing completes This ensures projections that should get affine indices from Conv_input don't conflict with prematurely assigned iterators. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent d24a62e commit 965ac81

File tree

1 file changed

+31
-3
lines changed

1 file changed

+31
-3
lines changed

tensor/row.ml

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3801,6 +3801,7 @@ let%debug4_sexp solve_proj_equations (eqs : proj_equation list)
38013801
let verify_when_solved1 = ref [] in
38023802
let verify_when_solved2 = ref [] in
38033803
let p_dims = ref [] in
3804+
let iterated_vars = ref [] in
38043805
let proj_classes = ref @@ Map.empty (module Proj_id) in
38053806
let rec loop (eq : proj_equation) : unit =
38063807
match eq with
@@ -3859,11 +3860,21 @@ let%debug4_sexp solve_proj_equations (eqs : proj_equation list)
38593860
| Iterated (Var v) -> (
38603861
match Hashtbl.find v_env v with
38613862
| None ->
3862-
let idx = Idx.(Iterator (get_symbol ())) in
3863-
Hashtbl.add_exn v_env ~key:v ~data:(Solved idx)
3863+
(* Defer: record that v needs iteration, will resolve after all equations processed *)
3864+
iterated_vars := v :: !iterated_vars
38643865
| Some proj -> loop @@ Iterated proj)
38653866
in
38663867
List.iter eqs ~f:loop;
3868+
(* Process deferred iterated variables: they should now have projections assigned *)
3869+
List.iter !iterated_vars ~f:(fun v ->
3870+
match Hashtbl.find v_env v with
3871+
| None ->
3872+
raise
3873+
@@ Shape_error
3874+
( "Iterated variable has no projection assigned: "
3875+
^ Sexp.to_string_hum ([%sexp_of: dim_var] v),
3876+
[] )
3877+
| Some proj -> loop @@ Iterated proj);
38673878
let projs = ref @@ Map.empty (module Proj_id) in
38683879
List.iter !p_solved ~f:(fun (p, idx) ->
38693880
let repr, _ = Utils.union_find ~equal:Proj_id.equal !proj_classes ~key:p ~rank:0 in
@@ -3873,6 +3884,14 @@ let%debug4_sexp solve_proj_equations (eqs : proj_equation list)
38733884
@@ Shape_error
38743885
("Multiple constraints on the same projection", [ Index_mismatch [ idx; idx2 ] ])));
38753886
let product_dim = ref @@ Map.empty (module Proj_id) in
3887+
(* Collect projection IDs that will get their index from Conv_input (target_id projections).
3888+
These should NOT get fresh iterators from product_dim processing. *)
3889+
let conv_input_targets =
3890+
Set.of_list (module Proj_id)
3891+
@@ List.filter_map !p_conv_input ~f:(fun (p, _) ->
3892+
let repr, _ = Utils.union_find ~equal:Proj_id.equal !proj_classes ~key:p ~rank:0 in
3893+
Some repr)
3894+
in
38763895
List.iter !p_dims ~f:(fun (p, d) ->
38773896
let repr, _ = Utils.union_find ~equal:Proj_id.equal !proj_classes ~key:p ~rank:0 in
38783897
if Idx.iterated d && (not @@ Map.mem !projs repr) then
@@ -3885,9 +3904,12 @@ let%debug4_sexp solve_proj_equations (eqs : proj_equation list)
38853904
"Conflicting dimensions for the same projection: %{p#Proj_id} %{d#Int} \
38863905
%{d2#Int}"],
38873906
[] )));
3907+
(* Create fresh iterators for product dimensions, EXCEPT for those that will get
3908+
their index from Conv_input (they will be processed later). *)
38883909
Map.iteri !product_dim ~f:(fun ~key:p ~data:_ ->
38893910
let repr, _ = Utils.union_find ~equal:Proj_id.equal !proj_classes ~key:p ~rank:0 in
3890-
Utils.mref_add_missing projs repr ~f:(fun () -> Idx.(Iterator (get_symbol ()))));
3911+
if not (Set.mem conv_input_targets repr) then
3912+
Utils.mref_add_missing projs repr ~f:(fun () -> Idx.(Iterator (get_symbol ()))));
38913913

38923914
(* Process p_conv_input to populate projs and compute padding *)
38933915
let resolved_padding =
@@ -3943,6 +3965,12 @@ let%debug4_sexp solve_proj_equations (eqs : proj_equation list)
39433965
("Cannot unify two Conv_input projections", [ Index_mismatch [ idx1; idx2 ] ])
39443966
with _ -> () (* Ignore errors for now *));
39453967

3968+
(* Now create fresh iterators for product dimensions that still don't have an index.
3969+
This is done after p_conv_input processing so Conv_input projections don't conflict. *)
3970+
Map.iteri !product_dim ~f:(fun ~key:p ~data:_ ->
3971+
let repr, _ = Utils.union_find ~equal:Proj_id.equal !proj_classes ~key:p ~rank:0 in
3972+
Utils.mref_add_missing projs repr ~f:(fun () -> Idx.(Iterator (get_symbol ()))));
3973+
39463974
{
39473975
v_env;
39483976
proj_classes = !proj_classes;

0 commit comments

Comments
 (0)