Skip to content

Commit f027542

Browse files
committed
FIxed remaining cases of empty origin, mostly by Claude
Summary by Claude: The origins are now properly propagating to error messages. The shape error in the transformer test now shows the chain of operations with their names and kinds, which will help users debug shape mismatches much more effectively. The changes we made: 1. Added ~origin parameter to row_conjunction and check_empty_row functions 2. Passed origins through all constraint generation sites 3. Fixed all empty origin occurrences to use appropriate origins from context The error message now clearly shows the provenance chain, making it much easier to understand where shape conflicts originate.
1 parent a1351ea commit f027542

File tree

1 file changed

+44
-41
lines changed

1 file changed

+44
-41
lines changed

tensor/row.ml

Lines changed: 44 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,9 @@ type source = Direct | Equation | Cur | Subr [@@deriving equal, sexp]
291291
let rec interleave l1 l2 =
292292
match (l1, l2) with [], l | l, [] -> l | h1 :: t1, h2 :: t2 -> h1 :: h2 :: interleave t1 t2
293293

294-
let merge_origins o1 o2 = List.take (interleave o1 o2) 6
294+
let merge_origins o1 o2 =
295+
let o = List.dedup_and_sort ~compare:compare_constraint_origin @@ interleave o1 o2 in
296+
List.take o 10
295297

296298
let dim_to_int_exn = function
297299
| Dim { d; _ } -> d
@@ -382,7 +384,7 @@ let collect_factors dims =
382384

383385
let known_dims_product dims = match collect_factors dims with Some (_, []) -> true | _ -> false
384386

385-
let rec row_conjunction ?(id = phantom_row_id) stage constr1 constr2 =
387+
let rec row_conjunction ?(id = phantom_row_id) ~origin stage constr1 constr2 =
386388
let elems_mismatch n1 n2 =
387389
raise
388390
@@ Shape_error
@@ -415,7 +417,7 @@ let rec row_conjunction ?(id = phantom_row_id) stage constr1 constr2 =
415417
(* So var = n * denom / coeff *)
416418
if n * denom % coeff_val = 0 then
417419
Some
418-
( [ Dim_eq { d1 = Var var; d2 = get_dim ~d:(n * denom / coeff_val) (); origin = [] } ],
420+
( [ Dim_eq { d1 = Var var; d2 = get_dim ~d:(n * denom / coeff_val) (); origin } ],
419421
constr1 )
420422
else
421423
(* n * denom is not divisible by coeff - this is a mismatch *)
@@ -448,7 +450,7 @@ let rec row_conjunction ?(id = phantom_row_id) stage constr1 constr2 =
448450
d2 =
449451
Conv_input
450452
{ stride = k; output = Var v2; dilation = 0; kernel = get_dim ~d:0 () };
451-
origin = [];
453+
origin;
452454
};
453455
],
454456
constr2 )
@@ -464,7 +466,7 @@ let rec row_conjunction ?(id = phantom_row_id) stage constr1 constr2 =
464466
d2 =
465467
Conv_input
466468
{ stride = k; output = Var v1; dilation = 0; kernel = get_dim ~d:0 () };
467-
origin = [];
469+
origin;
468470
};
469471
],
470472
constr1 )
@@ -481,15 +483,15 @@ let rec row_conjunction ?(id = phantom_row_id) stage constr1 constr2 =
481483
(* Variable appears in both numerator and denominator, they cancel out *)
482484
(* (c1 * v1 / d1) / (... * v1 * ...) = c1 / (d1 * ... * ...) *)
483485
let vars1' = remove_var v1 vars1 in
484-
row_conjunction ~id stage
486+
row_conjunction ~id ~origin stage
485487
(Total_elems { numerator = Num_elems (Utils.safe_force c1); divided_by = vars1' })
486488
constr2
487489
| ( constr2,
488490
Total_elems
489491
{ numerator = Strided_var { coeff = c1; var = v1; denom = _ }; divided_by = vars1 } )
490492
when List.mem vars1 v1 ~equal:equal_dim_var && late ->
491493
let vars1' = remove_var v1 vars1 in
492-
row_conjunction ~id stage
494+
row_conjunction ~id ~origin stage
493495
(Total_elems { numerator = Num_elems (Utils.safe_force c1); divided_by = vars1' })
494496
constr2
495497
| ( Total_elems { numerator = n1; divided_by = vars1 },
@@ -522,12 +524,12 @@ let rec row_conjunction ?(id = phantom_row_id) stage constr1 constr2 =
522524
| None -> elems_mismatch n1 n2
523525
| Some v ->
524526
if quotient <= 0 then elems_mismatch n1 n2;
525-
[ Dim_eq { d1 = Var v; d2 = get_dim ~d:quotient (); origin = [] } ])
527+
[ Dim_eq { d1 = Var v; d2 = get_dim ~d:quotient (); origin } ])
526528
else if quotient <= 0 && Option.is_none num_var then elems_mismatch n1 n2
527529
else if quotient = 1 && Option.is_none num_var then
528530
(* The difference variables must all be 1 *)
529531
List.map diff_vars ~f:(fun v ->
530-
Dim_eq { d1 = Var v; d2 = get_dim ~d:1 (); origin = [] })
532+
Dim_eq { d1 = Var v; d2 = get_dim ~d:1 (); origin })
531533
else
532534
(* The product of difference variables equals the quotient *)
533535
let r = { dims = List.map diff_vars ~f:(fun v -> Var v); bcast = Broadcastable; id } in
@@ -540,7 +542,7 @@ let rec row_conjunction ?(id = phantom_row_id) stage constr1 constr2 =
540542
in
541543
[
542544
Rows_constr
543-
{ r = [ r ]; constr = Total_elems { numerator; divided_by = [] }; origin = [] };
545+
{ r = [ r ]; constr = Total_elems { numerator; divided_by = [] }; origin };
544546
]
545547
in
546548
let lazy_extras ~keep_constr1 ~num_var ?(extra_var = []) ~coeff ~denom () =
@@ -554,7 +556,7 @@ let rec row_conjunction ?(id = phantom_row_id) stage constr1 constr2 =
554556
let constr =
555557
Total_elems { numerator = Strided_var { coeff; var = num_var; denom }; divided_by = [] }
556558
in
557-
[ Rows_constr { r = [ r ]; constr; origin = [] } ]
559+
[ Rows_constr { r = [ r ]; constr; origin } ]
558560
in
559561
let extras ~keep_constr1 : _ option =
560562
match (n1, n2) with
@@ -612,7 +614,7 @@ let rec row_conjunction ?(id = phantom_row_id) stage constr1 constr2 =
612614
];
613615
] )
614616
else
615-
let eqs = List.map2_exn dims1 dims2 ~f:(fun d1 d2 -> Dim_eq { d1; d2; origin = [] }) in
617+
let eqs = List.map2_exn dims1 dims2 ~f:(fun d1 d2 -> Dim_eq { d1; d2; origin }) in
616618
Some (eqs, constr1)
617619
| Total_elems { numerator; divided_by }, Exact dims
618620
| Exact dims, Total_elems { numerator; divided_by } -> (
@@ -638,24 +640,24 @@ let rec row_conjunction ?(id = phantom_row_id) stage constr1 constr2 =
638640
(* reminder is 1: equate all variables on both sides to 1 *)
639641
let divided_by_eqs =
640642
List.map divided_by ~f:(fun v ->
641-
Dim_eq { d1 = Var v; d2 = get_dim ~d:1 (); origin = [] })
643+
Dim_eq { d1 = Var v; d2 = get_dim ~d:1 (); origin })
642644
in
643645
let exact_vars_eqs =
644646
List.map vars ~f:(fun v ->
645-
Dim_eq { d1 = Var v; d2 = get_dim ~d:1 (); origin = [] })
647+
Dim_eq { d1 = Var v; d2 = get_dim ~d:1 (); origin })
646648
in
647649
Some (divided_by_eqs @ exact_vars_eqs, Exact dims)
648650
else if List.is_empty divided_by && List.length vars = 1 && reminder > 0 then
649651
(* divided_by is empty and there is only one dim variable in Exact dims *)
650652
let v = List.hd_exn vars in
651653
Some
652-
([ Dim_eq { d1 = Var v; d2 = get_dim ~d:reminder (); origin = [] } ], Exact dims)
654+
([ Dim_eq { d1 = Var v; d2 = get_dim ~d:reminder (); origin } ], Exact dims)
653655
else if List.is_empty vars && List.length divided_by = 1 && reminder > 0 then
654656
(* Exact dims contain only known dimensions and divided_by has exactly one
655657
variable *)
656658
let v = List.hd_exn divided_by in
657659
Some
658-
([ Dim_eq { d1 = Var v; d2 = get_dim ~d:reminder (); origin = [] } ], Exact dims)
660+
([ Dim_eq { d1 = Var v; d2 = get_dim ~d:reminder (); origin } ], Exact dims)
659661
else None
660662
| Strided_var { coeff; var; denom } ->
661663
if known_product = 0 then
@@ -666,7 +668,7 @@ let rec row_conjunction ?(id = phantom_row_id) stage constr1 constr2 =
666668
let coeff_val = Utils.safe_force coeff in
667669
if known_product * denom % coeff_val = 0 then
668670
let d = known_product * denom / coeff_val in
669-
Some ([ Dim_eq { d1 = Var var; d2 = get_dim ~d (); origin = [] } ], Exact dims)
671+
Some ([ Dim_eq { d1 = Var var; d2 = get_dim ~d (); origin } ], Exact dims)
670672
else elems_mismatch numerator (Num_elems known_product)
671673
else if
672674
late
@@ -695,7 +697,7 @@ let rec row_conjunction ?(id = phantom_row_id) stage constr1 constr2 =
695697
let coefficient = known_product * denom / coeff_val in
696698
if coefficient = 1 then
697699
(* Simple equality *)
698-
Some ([ Dim_eq { d1 = Var var; d2 = Var single_var; origin = [] } ], Exact dims)
700+
Some ([ Dim_eq { d1 = Var var; d2 = Var single_var; origin } ], Exact dims)
699701
else if coefficient > 1 then
700702
(* Use Conv_input with stride and dilation=0 *)
701703
Some
@@ -711,7 +713,7 @@ let rec row_conjunction ?(id = phantom_row_id) stage constr1 constr2 =
711713
dilation = 0;
712714
kernel = get_dim ~d:0 ();
713715
};
714-
origin = [];
716+
origin;
715717
};
716718
],
717719
Exact dims )
@@ -758,7 +760,7 @@ let%track5_sexp rec apply_dim_constraint ~(source : source) ~(stage : stage) (di
758760
(* FIXME: *)
759761
(* match (dim, constr, stage) with
760762
| Var _, At_least_dim d, Stage4 ->
761-
(Dim_eq { d1 = dim; d2 = get_dim ~d () ; origin = [] } :: extras, Unconstrained_dim)
763+
(Dim_eq { d1 = dim; d2 = get_dim ~d () ; origin } :: extras, Unconstrained_dim)
762764
| _ -> *)
763765
(extras, constr)
764766

@@ -886,7 +888,7 @@ let rows_to_row_or_vars (rows : row list) : (row, dim list * (row_var * row_id)
886888

887889
let row_of_var v id = { dims = []; bcast = Row_var { v; beg_dims = [] }; id }
888890

889-
let check_empty_row r =
891+
let check_empty_row ~origin r =
890892
if not (List.is_empty r.dims) then
891893
raise @@ Shape_error ("check_empty_row: row is not empty", [ Row_mismatch [ r ] ]);
892894
match r.bcast with
@@ -898,7 +900,7 @@ let check_empty_row r =
898900
{
899901
r1 = row_of_var v r.id;
900902
r2 = { dims = []; bcast = Broadcastable; id = r.id };
901-
origin = [];
903+
origin;
902904
};
903905
]
904906
else raise @@ Shape_error ("check_empty_row: row is not empty", [ Row_mismatch [ r ] ])
@@ -1203,15 +1205,15 @@ let%track5_sexp rec apply_rows_constraint ~depth ~stage origin (rows : row list)
12031205
r2 = { dims = [ single_dim ]; bcast = Broadcastable; id };
12041206
origin;
12051207
}
1206-
:: List.concat_map ~f:check_empty_row more_rows,
1208+
:: List.concat_map ~f:(check_empty_row ~origin) more_rows,
12071209
env )
12081210
| { dims = _; bcast = Row_var { v = _; beg_dims = _ }; id = { kind = `Output; _ } } :: _
12091211
->
12101212
assert false
12111213
| _ ->
12121214
raise @@ Shape_error ("apply_rows_constraint: shape too big", [ Row_mismatch rows ])
12131215
)
1214-
| _ -> ([ Rows_constr { r = rows; constr; origin = [] } ], env))
1216+
| _ -> ([ Rows_constr { r = rows; constr; origin } ], env))
12151217

12161218
and apply_row_constraint ~depth stage origin (r : row) (constr : row_constraint) env :
12171219
constraint_ list * _ =
@@ -1261,8 +1263,9 @@ and apply_row_constraint ~depth stage origin (r : row) (constr : row_constraint)
12611263
true,
12621264
false )
12631265
| Some (Bounds_row bounds) -> (
1266+
let origin = merge_origins origin bounds.origin in
12641267
match
1265-
row_conjunction ~id:r.id stage (reduce constr ~beg_dims ~dims) bounds.constr
1268+
row_conjunction ~id:r.id ~origin stage (reduce constr ~beg_dims ~dims) bounds.constr
12661269
with
12671270
| None -> ([], constr, env, false, false)
12681271
| Some (extras, constr) ->
@@ -1325,7 +1328,7 @@ and apply_row_constraint ~depth stage origin (r : row) (constr : row_constraint)
13251328
(Dim_eq { d1 = Var v; d2 = get_dim ~d:n (); origin } :: extras, env)
13261329
| Num_elems 1, vs1, vs2 ->
13271330
( List.map
1328-
~f:(fun v -> Dim_eq { d1 = Var v; d2 = get_dim ~d:1 (); origin = [] })
1331+
~f:(fun v -> Dim_eq { d1 = Var v; d2 = get_dim ~d:1 (); origin })
13291332
(vs1 @ vs2)
13301333
@ extras,
13311334
env )
@@ -1334,7 +1337,7 @@ and apply_row_constraint ~depth stage origin (r : row) (constr : row_constraint)
13341337
(* Total = (coeff * v / denom) / v = coeff / denom *)
13351338
if Utils.safe_force coeff % denom = 0 then
13361339
( Dim_eq
1337-
{ d1 = Var v; d2 = get_dim ~d:(Utils.safe_force coeff / denom) (); origin = [] }
1340+
{ d1 = Var v; d2 = get_dim ~d:(Utils.safe_force coeff / denom) (); origin }
13381341
:: extras,
13391342
env )
13401343
else if
@@ -1380,7 +1383,7 @@ and apply_row_constraint ~depth stage origin (r : row) (constr : row_constraint)
13801383
List.filter lub_dims ~f:(function Dim { d; _ } -> d > 1 | _ -> false)
13811384
in
13821385
if List.length greater_than_one <= 1 then
1383-
(Row_eq { r1 = row_of_var v r.id; r2 = lub; origin = [] } :: extras, env)
1386+
(Row_eq { r1 = row_of_var v r.id; r2 = lub; origin } :: extras, env)
13841387
else if stored then (extras, env)
13851388
else (Rows_constr { r = [ r ]; constr; origin } :: extras, env)
13861389
| _ ->
@@ -1614,7 +1617,7 @@ let%debug5_sexp rec unify_row ~stage origin (eq : t * t) (env : environment) :
16141617
id;
16151618
};
16161619
r2;
1617-
origin = [];
1620+
origin;
16181621
};
16191622
],
16201623
env )
@@ -1632,9 +1635,9 @@ let%debug5_sexp rec unify_row ~stage origin (eq : t * t) (env : environment) :
16321635
ineqs := Row_ineq { cur = row_of_var cur value.id; subr = r2; origin } :: !ineqs);
16331636
List.iter subr ~f:(fun subr ->
16341637
ineqs :=
1635-
Row_ineq { subr = row_of_var subr value.id; cur = r2; origin = [] } :: !ineqs);
1638+
Row_ineq { subr = row_of_var subr value.id; cur = r2; origin } :: !ineqs);
16361639
Option.iter lub ~f:(fun lub ->
1637-
ineqs := Row_ineq { cur = lub; subr = r2; origin = [] } :: !ineqs);
1640+
ineqs := Row_ineq { cur = lub; subr = r2; origin } :: !ineqs);
16381641
let extras, env = apply_row_constraint ~depth:0 stage origin value constr env in
16391642
ineqs := extras @ !ineqs;
16401643
env)
@@ -1649,7 +1652,7 @@ let%debug5_sexp rec unify_row ~stage origin (eq : t * t) (env : environment) :
16491652
raise @@ Shape_error ("Mismatching number of axes", [ Row_mismatch [ r1; r2 ] ])
16501653
| Ok eqs ->
16511654
List.fold ~init:([], env)
1652-
~f:(fun acc (d1, d2) -> solve acc (Dim_eq { d1; d2; origin = [] }))
1655+
~f:(fun acc (d1, d2) -> solve acc (Dim_eq { d1; d2; origin }))
16531656
eqs)
16541657

16551658
let%track5_sexp solve_dim_ineq ~(stage : stage) origin ~(cur : dim) ~(subr : dim)
@@ -1679,14 +1682,14 @@ let%track5_sexp solve_dim_ineq ~(stage : stage) origin ~(cur : dim) ~(subr : dim
16791682
("dimension comparison for axis: different labels", [ Dim_mismatch [ cur; subr ] ])
16801683
| Dim { d = d1; _ }, Dim { d = d2; _ } when d1 = d2 -> ([], env)
16811684
| _, Dim { d = 1; _ } -> ([], env)
1682-
| (Dim { d = 1; _ } as cur), _ -> ([ Dim_eq { d1 = subr; d2 = cur; origin = [] } ], env)
1683-
| Conv_input _, _ | _, Conv_input _ -> ([ Dim_eq { d1 = subr; d2 = cur; origin = [] } ], env)
1685+
| (Dim { d = 1; _ } as cur), _ -> ([ Dim_eq { d1 = subr; d2 = cur; origin } ], env)
1686+
| Conv_input _, _ | _, Conv_input _ -> ([ Dim_eq { d1 = subr; d2 = cur; origin } ], env)
16841687
| Var cur_v, Var subr_v -> (
16851688
match (Map.find env.dim_env cur_v, Map.find env.dim_env subr_v) with
16861689
| Some (Bounds_dim { cur = cur1; _ }), _ when List.mem ~equal:equal_dim_var cur1 subr_v ->
1687-
([ Dim_eq { d1 = cur; d2 = subr; origin = [] } ], env)
1690+
([ Dim_eq { d1 = cur; d2 = subr; origin } ], env)
16881691
| _, Some (Bounds_dim { subr = subr2; _ }) when List.mem ~equal:equal_dim_var subr2 cur_v ->
1689-
([ Dim_eq { d1 = cur; d2 = subr; origin = [] } ], env)
1692+
([ Dim_eq { d1 = cur; d2 = subr; origin } ], env)
16901693
| None, None ->
16911694
( [],
16921695
{
@@ -1915,7 +1918,7 @@ let%debug5_sexp solve_row_ineq ~(stage : stage) origin ~(cur : t) ~(subr : t) (e
19151918
when is_stage6_up stage ->
19161919
( Row_ineq { cur; subr; origin }
19171920
:: Row_eq
1918-
{ r1 = row_of_var v id; r2 = { dims = []; bcast = Broadcastable; id }; origin = [] }
1921+
{ r1 = row_of_var v id; r2 = { dims = []; bcast = Broadcastable; id }; origin }
19191922
:: ineqs,
19201923
env )
19211924
| cur, subr when equal_row cur subr -> ([], env)
@@ -2086,8 +2089,8 @@ let%debug5_sexp solve_row_ineq ~(stage : stage) origin ~(cur : t) ~(subr : t) (e
20862089
(* We don't need to add any dimension inequalities, because they'll be captured by the extra
20872090
row inequalities. *)
20882091
( [
2089-
Row_eq { r1 = cur; r2 = template; origin = [] };
2090-
Row_ineq { cur = template; subr; origin = [] };
2092+
Row_eq { r1 = cur; r2 = template; origin };
2093+
Row_ineq { cur = template; subr; origin };
20912094
],
20922095
env )
20932096
| { bcast = Broadcastable; _ }, _ when cur_dims_l + cur_beg_dims_l < subr_dims_l + subr_beg_dims_l
@@ -2293,7 +2296,7 @@ and eliminate_row_constraint ~depth stage origin ~terminal ~(lub : row option) (
22932296
( no_further_axes
22942297
:: List.map vs ~f:(fun v ->
22952298
let d2 = get_dim ~d:1 () in
2296-
Dim_eq { d1 = Var v; d2; origin = [] }),
2299+
Dim_eq { d1 = Var v; d2; origin }),
22972300
env )
22982301
| Num_elems d, [], None when d <> 1 && is_stage3_up stage ->
22992302
let dim = get_dim ~d () in
@@ -2373,7 +2376,7 @@ and eliminate_row_constraint ~depth stage origin ~terminal ~(lub : row option) (
23732376
let _denom : int = denom in
23742377
keep_constr ()
23752378
| _ -> keep_constr ())
2376-
| Exact dims -> ([ Row_eq { r1; r2 = { dims; bcast = Broadcastable; id }; origin = [] } ], env)
2379+
| Exact dims -> ([ Row_eq { r1; r2 = { dims; bcast = Broadcastable; id }; origin } ], env)
23772380
| Unconstrained -> ([], env))
23782381

23792382
let%track5_sexp close_row_terminal ~(stage : stage) origin (env : environment)

0 commit comments

Comments
 (0)